github.com/kaydxh/golang@v0.0.131/go/net/resolver/dns/dns_resolver.go (about)

     1  /*
     2   *Copyright (c) 2022, kaydxh
     3   *
     4   *Permission is hereby granted, free of charge, to any person obtaining a copy
     5   *of this software and associated documentation files (the "Software"), to deal
     6   *in the Software without restriction, including without limitation the rights
     7   *to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     8   *copies of the Software, and to permit persons to whom the Software is
     9   *furnished to do so, subject to the following conditions:
    10   *
    11   *The above copyright notice and this permission notice shall be included in all
    12   *copies or substantial portions of the Software.
    13   *
    14   *THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    15   *IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    16   *FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    17   *AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    18   *LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    19   *OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    20   *SOFTWARE.
    21   */
    22  package dns
    23  
    24  import (
    25  	"context"
    26  	"fmt"
    27  	"net"
    28  	"strconv"
    29  	"sync"
    30  	"time"
    31  
    32  	rand_ "github.com/kaydxh/golang/go/math/rand"
    33  	net_ "github.com/kaydxh/golang/go/net"
    34  	"github.com/kaydxh/golang/go/net/resolver"
    35  	time_ "github.com/kaydxh/golang/go/time"
    36  )
    37  
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records: while it is false, lookupSRV returns no
// addresses and no error without querying DNS. Must not be changed after init
// time.
var EnableSRVLookups = false

// Globals to stub out in tests.
var (
	// newTimer creates every timer the watcher loop sleeps on (sync interval
	// and back-off); replacing it lets tests control that timing.
	newTimer = time.NewTimer
)
    47  
    48  func init() {
    49  	resolver.Register(NewBuilder())
    50  }
    51  
// Fallback ports handed to net_.ParseTarget when an address carries none:
// defaultPort for the target endpoint parsed in Build, defaultDNSSvrPort for a
// custom DNS authority parsed in customAuthorityResolver.
const (
	defaultPort       = "443"
	defaultDNSSvrPort = "53"
)

var (
	// defaultResolver performs the lookups of targets that name no authority.
	defaultResolver netResolver = net.DefaultResolver
	// To prevent excessive re-resolution, we enforce a rate limit on DNS
	// resolution requests. NewBuilder uses this as the sync interval when no
	// option sets one; after a successful resolution the watcher then waits
	// at least that long before resolving again.
	defaultSyncInterval = 30 * time.Second
)
    63  
    64  var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
    65  	return func(ctx context.Context, network, address string) (net.Conn, error) {
    66  		var dialer net.Dialer
    67  		return dialer.DialContext(ctx, network, authority)
    68  	}
    69  }
    70  
    71  var customAuthorityResolver = func(authority string) (netResolver, error) {
    72  	host, port, err := net_.ParseTarget(authority, defaultDNSSvrPort)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	authorityWithPort := net.JoinHostPort(host, port)
    78  
    79  	return &net.Resolver{
    80  		PreferGo: true,
    81  		Dial:     customAuthorityDialler(authorityWithPort),
    82  	}, nil
    83  }
    84  
    85  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
    86  func NewBuilder(opts ...dnsBuilderOption) resolver.Builder {
    87  	b := &dnsBuilder{}
    88  	b.ApplyOptions(opts...)
    89  	if b.opts.syncInterval == 0 {
    90  		b.opts.syncInterval = defaultSyncInterval
    91  	}
    92  
    93  	return b
    94  }
    95  
// dnsBuilder is the resolver.Builder registered for the "dns" scheme; it is
// created by NewBuilder and produces resolvers through Build.
type dnsBuilder struct {
	opts struct {
		// syncInterval is copied into every dnsResolver built; the watcher
		// waits at least this long between successful resolutions. NewBuilder
		// replaces an unset (zero) value with defaultSyncInterval.
		syncInterval time.Duration
	}
}
   101  
   102  // Build creates and starts a DNS resolver that watches the name resolution of the target.
   103  func (b *dnsBuilder) Build(target resolver.Target, opts ...resolver.ResolverBuildOption) (resolver.Resolver, error) {
   104  	var opt resolver.ResolverBuildOptions
   105  	opt.ApplyOptions(opts...)
   106  	host, port, err := net_.ParseTarget(target.Endpoint, defaultPort)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	cc := opt.Cc
   111  
   112  	// IP address.
   113  	if ipAddr, ok := formatIP(host); ok {
   114  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
   115  		if cc != nil {
   116  			cc.UpdateState(resolver.State{Addresses: addr})
   117  		}
   118  		return deadResolver{
   119  			addrs: addr,
   120  		}, nil
   121  	}
   122  
   123  	// DNS address (non-IP).
   124  	ctx, cancel := context.WithCancel(context.Background())
   125  	d := &dnsResolver{
   126  		host:         host,
   127  		port:         port,
   128  		syncInterval: b.opts.syncInterval,
   129  		ctx:          ctx,
   130  		cancel:       cancel,
   131  		cc:           cc,
   132  		rn:           make(chan struct{}, 1),
   133  	}
   134  
   135  	if target.Authority == "" {
   136  		d.resolver = defaultResolver
   137  	} else {
   138  		d.resolver, err = customAuthorityResolver(target.Authority)
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  	}
   143  
   144  	d.wg.Add(1)
   145  	go d.watcher()
   146  	return d, nil
   147  }
   148  
   149  // Scheme returns the naming scheme of this resolver builder, which is "dns".
   150  func (b *dnsBuilder) Scheme() string {
   151  	return "dns"
   152  }
   153  
// netResolver is the subset of *net.Resolver's lookup methods this package
// uses. Both net.DefaultResolver and the resolver built by
// customAuthorityResolver satisfy it; presumably it also exists so tests can
// substitute a fake.
type netResolver interface {
	LookupHost(ctx context.Context, host string) (addrs []string, err error)
	LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
	LookupTXT(ctx context.Context, name string) (txts []string, err error)
}
   159  
// deadResolver is a resolver that does nothing: Build returns it for
// IP-literal targets, it serves a fixed address list, and it never
// re-resolves.
type deadResolver struct {
	// addrs is the static address list served by ResolveAll and ResolveOne.
	addrs []resolver.Address
}
   164  
   165  func (d deadResolver) ResolveOne(opts ...resolver.ResolveOneOption) (resolver.Address, error) {
   166  	var opt resolver.ResolveOneOptions
   167  	opt.ApplyOptions(opts...)
   168  
   169  	addrs, err := d.ResolveAll(resolver.WithIPTypeForResolveAll(opt.IPType))
   170  	if err != nil {
   171  		return resolver.Address{}, err
   172  	}
   173  
   174  	switch opt.PickMode {
   175  	case resolver.Resolver_pick_mode_random:
   176  		return addrs[rand_.Intn(len(addrs))], nil
   177  	case resolver.Resolver_pick_mode_first:
   178  		return addrs[0], nil
   179  	default:
   180  		return addrs[rand_.Intn(len(addrs))], nil
   181  
   182  	}
   183  }
   184  
   185  func (d deadResolver) ResolveAll(opts ...resolver.ResolveAllOption) ([]resolver.Address, error) {
   186  	var opt resolver.ResolveAllOptions
   187  	opt.ApplyOptions(opts...)
   188  	if len(d.addrs) == 0 {
   189  		return nil, fmt.Errorf("resolve target's addresses are empty")
   190  	}
   191  
   192  	var pickAddrs []resolver.Address
   193  	if opt.IPType == resolver.Resolver_ip_type_all {
   194  		pickAddrs = d.addrs
   195  	} else {
   196  		for _, addr := range d.addrs {
   197  			v4 := (opt.IPType == resolver.Resolver_ip_type_v4)
   198  			ip, _, _ := net_.SplitHostIntPort(addr.Addr)
   199  			if net_.IsIPv4String(ip) {
   200  				if v4 {
   201  					pickAddrs = append(pickAddrs, addr)
   202  				}
   203  			} else {
   204  				//v6
   205  				if !v4 {
   206  					pickAddrs = append(pickAddrs, addr)
   207  				}
   208  			}
   209  		}
   210  	}
   211  	if len(pickAddrs) == 0 {
   212  		return nil, fmt.Errorf("resolve target's addresses type[%v] are empty", opt.IPType)
   213  	}
   214  	return pickAddrs, nil
   215  }
   216  
// ResolveNow is a no-op: a deadResolver's addresses never change.
func (deadResolver) ResolveNow(opts ...resolver.ResolveNowOption) {}

// Close is a no-op: a deadResolver starts no goroutine and holds nothing to
// release.
func (deadResolver) Close() {}
   220  
// dnsResolver watches for name resolution updates of a non-IP target. Build
// starts its watcher goroutine, which keeps cc (when non-nil) updated with the
// addresses of host, each joined with port; ResolveAll and ResolveOne also
// perform a synchronous lookup on demand. resolver is the backend used for all
// lookups, and syncInterval is the minimum delay between successful
// resolutions in the watcher. ctx is cancelled by Close via cancel.
type dnsResolver struct {
	host         string
	port         string
	resolver     netResolver
	syncInterval time.Duration

	ctx    context.Context
	cancel context.CancelFunc
	cc     resolver.ClientConn
	// rn channel is used by ResolveNow() to force an immediate resolution of the target.
	rn chan struct{}
	// wg makes Close() return only after the watcher() goroutine has finished.
	// Otherwise a data race is possible. [Race example] dns_resolver_test
	// replaces the real lookup functions with mocked ones to facilitate
	// testing. If Close() did not wait for the watcher() goroutine to finish,
	// the race detector could report that a lookup inside watcher() (a READ of
	// the lookup function pointers) races with replaceNetFunc (a WRITE of the
	// lookup function pointers).
	wg sync.WaitGroup
}
   241  
   242  func (d *dnsResolver) ResolveOne(opts ...resolver.ResolveOneOption) (resolver.Address, error) {
   243  	var opt resolver.ResolveOneOptions
   244  	opt.ApplyOptions(opts...)
   245  
   246  	addrs, err := d.ResolveAll(resolver.WithIPTypeForResolveAll(opt.IPType))
   247  	if err != nil {
   248  		return resolver.Address{}, err
   249  	}
   250  
   251  	switch opt.PickMode {
   252  	case resolver.Resolver_pick_mode_random:
   253  		return addrs[rand_.Intn(len(addrs))], nil
   254  	case resolver.Resolver_pick_mode_first:
   255  		return addrs[0], nil
   256  	default:
   257  		return addrs[rand_.Intn(len(addrs))], nil
   258  
   259  	}
   260  }
   261  
   262  func (d *dnsResolver) ResolveAll(opts ...resolver.ResolveAllOption) ([]resolver.Address, error) {
   263  	var opt resolver.ResolveAllOptions
   264  	opt.ApplyOptions(opts...)
   265  	d.ResolveNow()
   266  	addrs, err := d.lookupHost()
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	if len(addrs) == 0 {
   271  		return nil, fmt.Errorf("resolve target's addresses are empty")
   272  	}
   273  
   274  	var pickAddrs []resolver.Address
   275  	if opt.IPType == resolver.Resolver_ip_type_all {
   276  		pickAddrs = addrs
   277  	} else {
   278  		for _, addr := range addrs {
   279  			v4 := (opt.IPType == resolver.Resolver_ip_type_v4)
   280  			ip, _, _ := net_.SplitHostIntPort(addr.Addr)
   281  			if net_.IsIPv4String(ip) {
   282  				if v4 {
   283  					pickAddrs = append(pickAddrs, addr)
   284  				}
   285  			} else {
   286  				//v6
   287  				if !v4 {
   288  					pickAddrs = append(pickAddrs, addr)
   289  				}
   290  			}
   291  		}
   292  	}
   293  	if len(pickAddrs) == 0 {
   294  		return nil, fmt.Errorf("resolve target's addresses type[%v] are empty", opt.IPType)
   295  	}
   296  	return pickAddrs, nil
   297  }
   298  
// ResolveNow asks the watcher goroutine to resolve the target again. The send
// on d.rn never blocks. d.rn is created by Build with capacity 1, so if a
// request is already pending the new one is simply dropped, because the
// pending request will trigger the resolution anyway. After a successful
// resolution the watcher still waits for d.syncInterval, so the requested
// resolution may be delayed. opts are ignored.
func (d *dnsResolver) ResolveNow(opts ...resolver.ResolveNowOption) {
	select {
	case d.rn <- struct{}{}:
	default:
		// A request is already queued; coalesce with it.
	}
}
   306  
// Close closes the dnsResolver. It cancels d.ctx, which is passed to every
// lookup and makes the watcher loop return, and then blocks until the watcher
// goroutine has exited (see the wg field). The cancel must come first, or
// Wait would never return.
func (d *dnsResolver) Close() {
	d.cancel()
	d.wg.Wait()
}
   312  
   313  func (d *dnsResolver) watcher() {
   314  	defer d.wg.Done()
   315  
   316  	backoff := time_.NewExponentialBackOff()
   317  	for {
   318  		addrs, err := d.lookupHost()
   319  		if d.cc != nil {
   320  			if err != nil {
   321  				// Report error to the underlying grpc.ClientConn.
   322  				d.cc.ReportError(err)
   323  			} else {
   324  				err = d.cc.UpdateState(resolver.State{Addresses: addrs})
   325  			}
   326  		}
   327  
   328  		var timer *time.Timer
   329  		if err == nil {
   330  			// Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least
   331  			// to prevent constantly re-resolving.
   332  			backoff.Reset()
   333  			timer = newTimer(d.syncInterval)
   334  			select {
   335  			case <-d.ctx.Done():
   336  				timer.Stop()
   337  				return
   338  			case <-d.rn:
   339  			}
   340  		} else {
   341  			// Poll on an error found in DNS Resolver or an error received from ClientConn.
   342  			actualInterval, _ := backoff.NextBackOff()
   343  			timer = newTimer(actualInterval)
   344  		}
   345  		select {
   346  		case <-d.ctx.Done():
   347  			timer.Stop()
   348  			return
   349  		case <-timer.C:
   350  		}
   351  	}
   352  }
   353  
   354  func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
   355  	if !EnableSRVLookups {
   356  		return nil, nil
   357  	}
   358  	var newAddrs []resolver.Address
   359  	_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
   360  	if err != nil {
   361  		err = handleDNSError(err, "SRV") // may become nil
   362  		return nil, err
   363  	}
   364  	for _, s := range srvs {
   365  		lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   366  		if err != nil {
   367  			err = handleDNSError(err, "A") // may become nil
   368  			if err == nil {
   369  				// If there are other SRV records, look them up and ignore this
   370  				// one that does not exist.
   371  				continue
   372  			}
   373  			return nil, err
   374  		}
   375  		for _, a := range lbAddrs {
   376  			ip, ok := formatIP(a)
   377  			if !ok {
   378  				return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   379  			}
   380  			addr := ip + ":" + strconv.Itoa(int(s.Port))
   381  			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   382  		}
   383  	}
   384  	return newAddrs, nil
   385  }
   386  
   387  func handleDNSError(err error, lookupType string) error {
   388  	if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   389  		// Timeouts and temporary errors should be communicated to gRPC to
   390  		// attempt another DNS query (with backoff).  Other errors should be
   391  		// suppressed (they may represent the absence of a TXT record).
   392  		return nil
   393  	}
   394  	if err != nil {
   395  		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
   396  	}
   397  	return err
   398  }
   399  
   400  func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
   401  	addrs, err := d.resolver.LookupHost(d.ctx, d.host)
   402  	if err != nil {
   403  		err = handleDNSError(err, "A")
   404  		return nil, err
   405  	}
   406  	newAddrs := make([]resolver.Address, 0, len(addrs))
   407  	for _, a := range addrs {
   408  		ip, ok := formatIP(a)
   409  		if !ok {
   410  			return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
   411  		}
   412  		addr := ip + ":" + d.port
   413  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
   414  	}
   415  	return newAddrs, nil
   416  }
   417  
   418  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   419  // If addr is an IPv4 address, return the addr and ok = true.
   420  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   421  func formatIP(addr string) (addrIP string, ok bool) {
   422  	ip := net.ParseIP(addr)
   423  	if ip == nil {
   424  		return "", false
   425  	}
   426  	if ip.To4() != nil {
   427  		return addr, true
   428  	}
   429  	return "[" + addr + "]", true
   430  }