github.com/searKing/golang/go@v1.2.117/net/resolver/dns/dns_resolver.go (about)

     1  // Copyright 2021 The searKing Author. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dns
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"strconv"
    13  	"sync"
    14  	"time"
    15  
    16  	rand_ "github.com/searKing/golang/go/math/rand"
    17  	"github.com/searKing/golang/go/net/resolver"
    18  	time_ "github.com/searKing/golang/go/time"
    19  )
    20  
// EnableSRVLookups controls whether the DNS resolver attempts to fetch
// addresses from SRV records. Must not be changed after init time.
// While it is false, lookupSRV returns immediately without querying DNS.
var EnableSRVLookups = false

// Globals to stub out in tests.
// newTimer creates the timers the watcher goroutine waits on.
var newTimer = time.NewTimer
    27  
    28  func init() {
    29  	resolver.Register(NewBuilder())
    30  }
    31  
// Default ports passed to parseTarget for inputs that carry no port:
// defaultPort for resolution targets (see Build) and defaultDNSSvrPort for a
// custom DNS authority (see customAuthorityResolver).
const (
	defaultPort       = "443"
	defaultDNSSvrPort = "53"
)
    36  
var (
	// errMissingAddr is returned by parseTarget when the target string is empty.
	errMissingAddr = errors.New("dns resolver: missing address")

	// errEndsWithColon is returned by parseTarget for an address that ends
	// with the colon separating host and port, which is not allowed.
	// E.g. "::" is a valid address as it is an IPv6 address (host only), while
	// "[::]:" is invalid as it ends with a colon as the host and port separator.
	errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon")
)
    46  
var (
	// defaultResolver performs lookups for targets that do not name a custom
	// DNS authority (see Build).
	defaultResolver netResolver = net.DefaultResolver
	// To prevent excessive re-resolution, we enforce a rate limit on DNS
	// resolution requests: after a successful resolution the watcher waits at
	// least this long before resolving again.
	minDNSResRate = 30 * time.Second
)
    53  
    54  var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
    55  	return func(ctx context.Context, network, address string) (net.Conn, error) {
    56  		var dialer net.Dialer
    57  		return dialer.DialContext(ctx, network, authority)
    58  	}
    59  }
    60  
    61  var customAuthorityResolver = func(authority string) (netResolver, error) {
    62  	host, port, err := parseTarget(authority, defaultDNSSvrPort)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	authorityWithPort := net.JoinHostPort(host, port)
    68  
    69  	return &net.Resolver{
    70  		PreferGo: true,
    71  		Dial:     customAuthorityDialler(authorityWithPort),
    72  	}, nil
    73  }
    74  
    75  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
    76  func NewBuilder() resolver.Builder {
    77  	return &dnsBuilder{}
    78  }
    79  
// dnsBuilder is the stateless builder returned by NewBuilder; its Build and
// Scheme methods implement the "dns" resolver scheme.
type dnsBuilder struct{}
    81  
    82  // Build creates and starts a DNS resolver that watches the name resolution of the target.
    83  func (b *dnsBuilder) Build(ctx context.Context, target resolver.Target, opts ...resolver.BuildOption) (resolver.Resolver, error) {
    84  	var opt resolver.Build
    85  	opt.ApplyOptions(opts...)
    86  	host, port, err := parseTarget(target.Endpoint, defaultPort)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	cc := opt.ClientConn
    91  
    92  	// IP address.
    93  	if ipAddr, ok := formatIP(host); ok {
    94  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
    95  		if cc != nil {
    96  			_ = cc.UpdateState(resolver.State{Addresses: addr})
    97  		}
    98  		return deadResolver{
    99  			addrs: addr,
   100  		}, nil
   101  	}
   102  
   103  	// DNS address (non-IP).
   104  	ctx, cancel := context.WithCancel(context.Background())
   105  	d := &dnsResolver{
   106  		host:   host,
   107  		port:   port,
   108  		ctx:    ctx,
   109  		cancel: cancel,
   110  		cc:     cc,
   111  		rn:     make(chan struct{}, 1),
   112  	}
   113  
   114  	if target.Authority == "" {
   115  		d.resolver = defaultResolver
   116  	} else {
   117  		d.resolver, err = customAuthorityResolver(target.Authority)
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  	}
   122  
   123  	d.wg.Add(1)
   124  	go d.watcher()
   125  	return d, nil
   126  }
   127  
   128  // Scheme returns the naming scheme of this resolver builder, which is "dns".
   129  func (b *dnsBuilder) Scheme() string {
   130  	return "dns"
   131  }
   132  
// netResolver is the set of lookup methods this package needs from a DNS
// resolver. Both net.DefaultResolver and the *net.Resolver created by
// customAuthorityResolver are used through it.
type netResolver interface {
	LookupHost(ctx context.Context, host string) (addrs []string, err error)
	LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
	LookupTXT(ctx context.Context, name string) (txts []string, err error)
}
   138  
// deadResolver is a resolver that does nothing: it performs no lookups and
// only serves the fixed address list it was created with. Build returns it
// for targets that are already IP literals.
// NOTE(review): picker is not referenced anywhere in the visible source — confirm it is still needed.
type deadResolver struct {
	picker resolver.Picker
	addrs  []resolver.Address
}
   144  
   145  func (d deadResolver) ResolveOneAddr(ctx context.Context, opts ...resolver.ResolveOneAddrOption) (resolver.Address, error) {
   146  	if len(d.addrs) == 0 {
   147  		return resolver.Address{}, fmt.Errorf("resolve target, but no addr")
   148  	}
   149  	return d.addrs[rand_.Intn(len(d.addrs))], nil
   150  }
// ResolveAddr returns the fixed address list as is (possibly empty) and never
// fails. ctx and opts are not used.
func (d deadResolver) ResolveAddr(ctx context.Context, opts ...resolver.ResolveAddrOption) ([]resolver.Address, error) {
	return d.addrs, nil
}
// ResolveNow is a no-op: the address list is fixed, so there is nothing to re-resolve.
func (deadResolver) ResolveNow(ctx context.Context, opts ...resolver.ResolveNowOption) {}
   155  
// Close is a no-op: deadResolver holds no goroutines or other resources.
func (deadResolver) Close() {}
   157  
// dnsResolver watches for the name resolution update for a non-IP target.
// Build creates it and starts its watcher goroutine; Close stops that
// goroutine and waits for it to exit.
type dnsResolver struct {
	host     string
	port     string
	resolver netResolver
	ctx      context.Context
	cancel   context.CancelFunc
	cc       resolver.ClientConn
	// rn channel is used by ResolveNow() to force an immediate resolution of the target.
	rn chan struct{}
	// wg is used to make Close() return only after the watcher() goroutine has finished.
	// Otherwise, a data race is possible. [Race Example] In dns_resolver_test we
	// replace the real lookup functions with mocked ones to facilitate testing.
	// If Close() does not wait for the watcher() goroutine to finish, the race detector
	// sometimes warns that lookup (READ of the lookup function pointers) inside the
	// watcher() goroutine races with replaceNetFunc (WRITE of the lookup function pointers).
	wg sync.WaitGroup
}
   176  
   177  func (d *dnsResolver) ResolveOneAddr(ctx context.Context, opts ...resolver.ResolveOneAddrOption) (resolver.Address, error) {
   178  	d.ResolveNow(ctx)
   179  	addrs, err := d.lookupHost()
   180  	if err != nil {
   181  		return resolver.Address{}, err
   182  	}
   183  	if len(addrs) == 0 {
   184  		return resolver.Address{}, fmt.Errorf("resolve target, but no addr")
   185  	}
   186  	return addrs[rand_.Intn(len(addrs))], nil
   187  }
   188  
   189  func (d *dnsResolver) ResolveAddr(ctx context.Context, opts ...resolver.ResolveAddrOption) ([]resolver.Address, error) {
   190  	d.ResolveNow(ctx)
   191  	return d.lookupHost()
   192  }
   193  
// ResolveNow requests an immediate resolution of the target that this
// dnsResolver watches. It never blocks; ctx and opts are not used.
func (d *dnsResolver) ResolveNow(ctx context.Context, opts ...resolver.ResolveNowOption) {
	select {
	case d.rn <- struct{}{}:
	default:
		// The send would block (a request is already pending or nobody is
		// receiving), so this request is dropped instead of waiting.
	}
}
   201  
// Close closes the dnsResolver: it cancels the resolver's context and blocks
// until the watcher goroutine has exited.
func (d *dnsResolver) Close() {
	// Cancel first so the watcher's selects on d.ctx.Done() can return.
	d.cancel()
	d.wg.Wait()
}
   207  
// watcher is the resolution loop started by Build in its own goroutine. Each
// iteration resolves the host, reports the outcome to the ClientConn (when one
// is set) and then waits before resolving again. It returns when d.ctx is
// cancelled, marking d.wg done on the way out so Close can return.
func (d *dnsResolver) watcher() {
	defer d.wg.Done()

	backoff := time_.NewGrpcExponentialBackOff()
	for {
		addrs, err := d.lookupHost()
		if d.cc != nil {
			if err != nil {
				// Report error to the ClientConn.
				d.cc.ReportError(err)
			} else {
				// An error from UpdateState is treated like a lookup error
				// below: it selects the backoff path.
				err = d.cc.UpdateState(resolver.State{Addresses: addrs})
			}
		}

		var timer *time.Timer
		if err == nil {
			// Success resolving, wait for the next ResolveNow. However, also wait minDNSResRate at the very least
			// to prevent constantly re-resolving.
			backoff.Reset()
			// The timer starts before waiting for ResolveNow, so the time spent
			// waiting here counts towards the minimum interval.
			timer = newTimer(minDNSResRate)
			select {
			case <-d.ctx.Done():
				timer.Stop()
				return
			case <-d.rn:
			}
		} else {
			// Poll on an error found in DNS Resolver or an error received from ClientConn.
			// NOTE(review): the second return value of NextBackOff is discarded — confirm that is intended.
			bc, _ := backoff.NextBackOff()
			timer = newTimer(bc)
		}
		// Both paths end here: wait for the timer to fire (or for cancellation)
		// before starting the next resolution.
		select {
		case <-d.ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
   248  
   249  func (d *dnsResolver) lookupSRV(service, proto string) ([]string, error) {
   250  	if !EnableSRVLookups {
   251  		return nil, nil
   252  	}
   253  	var newAddrs []string
   254  	_, srvs, err := d.resolver.LookupSRV(d.ctx, service, proto, d.host)
   255  	if err != nil {
   256  		err = handleDNSError(err, "SRV") // may become nil
   257  		return nil, err
   258  	}
   259  	for _, s := range srvs {
   260  		lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   261  		if err != nil {
   262  			err = handleDNSError(err, "A") // may become nil
   263  			if err == nil {
   264  				// If there are other SRV records, look them up and ignore this
   265  				// one that does not exist.
   266  				continue
   267  			}
   268  			return nil, err
   269  		}
   270  		for _, a := range lbAddrs {
   271  			ip, ok := formatIP(a)
   272  			if !ok {
   273  				return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   274  			}
   275  			addr := ip + ":" + strconv.Itoa(int(s.Port))
   276  			newAddrs = append(newAddrs, addr)
   277  		}
   278  	}
   279  	return newAddrs, nil
   280  }
   281  
   282  var filterError = func(err error) error {
   283  	if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   284  		// Timeouts and temporary errors should be communicated to gRPC to
   285  		// attempt another DNS query (with backoff).  Other errors should be
   286  		// suppressed (they may represent the absence of a TXT record).
   287  		return nil
   288  	}
   289  	return err
   290  }
   291  
   292  func handleDNSError(err error, lookupType string) error {
   293  	err = filterError(err)
   294  	if err != nil {
   295  		err = fmt.Errorf("dns: %v record lookup error: %w", lookupType, err)
   296  		return err
   297  	}
   298  	return nil
   299  }
   300  
   301  func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
   302  	var newAddrs []resolver.Address
   303  	addrs, err := d.resolver.LookupHost(d.ctx, d.host)
   304  	if err != nil {
   305  		err = handleDNSError(err, "A")
   306  		return nil, err
   307  	}
   308  	for _, a := range addrs {
   309  		ip, ok := formatIP(a)
   310  		if !ok {
   311  			return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   312  		}
   313  		addr := ip + ":" + d.port
   314  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
   315  	}
   316  	return newAddrs, nil
   317  }
   318  
   319  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   320  // If addr is an IPv4 address, return the addr and ok = true.
   321  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   322  func formatIP(addr string) (addrIP string, ok bool) {
   323  	ip := net.ParseIP(addr)
   324  	if ip == nil {
   325  		return "", false
   326  	}
   327  	if ip.To4() != nil {
   328  		return addr, true
   329  	}
   330  	return "[" + addr + "]", true
   331  }
   332  
   333  // parseTarget takes the user input target string and default port, returns formatted host and port info.
   334  // If target doesn't specify a port, set the port to be the defaultPort.
   335  // If target is in IPv6 format and host-name is enclosed in square brackets, brackets
   336  // are stripped when setting the host.
   337  // examples:
   338  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   339  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   340  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   341  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   342  func parseTarget(target, defaultPort string) (host, port string, err error) {
   343  	if target == "" {
   344  		return "", "", errMissingAddr
   345  	}
   346  	if ip := net.ParseIP(target); ip != nil {
   347  		// target is an IPv4 or IPv6(without brackets) address
   348  		return target, defaultPort, nil
   349  	}
   350  	if host, port, err = net.SplitHostPort(target); err == nil {
   351  		if port == "" {
   352  			// If the port field is empty (target ends with colon), e.g. "[::1]:", this is an error.
   353  			return "", "", errEndsWithColon
   354  		}
   355  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   356  		if host == "" {
   357  			// Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed.
   358  			host = "localhost"
   359  		}
   360  		return host, port, nil
   361  	}
   362  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
   363  		// target doesn't have port
   364  		return host, port, nil
   365  	}
   366  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   367  }