github.com/mailgun/holster/v4@v4.20.0/discovery/grpc_srv_resolver.go (about)

     1  package discovery
     2  
     3  // Based on grpc-go/internal/resolver/dns
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/mailgun/holster/v4/cancel"
    15  	"github.com/mailgun/holster/v4/retry"
    16  	"github.com/sirupsen/logrus"
    17  	"google.golang.org/grpc/resolver"
    18  )
    19  
// init installs the logrus standard logger as the initial value of
// GRPCSrvDefaultLogger, the logger the SRV resolver writes to.
func init() {
	GRPCSrvDefaultLogger = logrus.StandardLogger()
}
    23  
// Sentinel errors and package-level settings for the dns-srv resolver.
//
//   - ErrMissingAddr is returned by parseTarget when the target is empty.
//   - ErrEndsWithColon is returned by parseTarget when the target ends with a
//     port-separator colon that is not followed by a port.
//   - ErrIPAddressNotAllowed is returned by Build when the target host is an IP
//     address instead of a DNS name.
//   - GRPCSrvDefaultPort is the default port Build hands to parseTarget for
//     targets that do not specify one.
//   - GRPCSrvDefaultLogger is the logger used by the resolver; init sets it to
//     the logrus standard logger.
var (
	ErrMissingAddr         = errors.New("missing address")
	ErrEndsWithColon       = errors.New("missing port after port-separator colon")
	ErrIPAddressNotAllowed = errors.New("ip address is not allowed; must be a dns service name")
	GRPCSrvDefaultPort     = "443"
	GRPCSrvDefaultLogger   logrus.FieldLogger

	// GRPCSrvLogAddresses, when true, makes the resolver log (at Info level) the
	// list of addresses obtained each time an SRV lookup produces a new address list.
	GRPCSrvLogAddresses = false
)
    34  
    35  // NewGRPCSRVBuilder creates a srvResolverBuilder which is used to factory SRV-DNS resolvers.
    36  func NewGRPCSRVBuilder() resolver.Builder {
    37  	return &srvResolverBuilder{}
    38  }
    39  
// srvResolverBuilder is the stateless resolver.Builder returned by
// NewGRPCSRVBuilder; its Build method creates srvResolver instances.
type srvResolverBuilder struct{}
    41  
    42  func (*srvResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
    43  	host, port, err := parseTarget(target.Endpoint(), GRPCSrvDefaultPort)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	// IP address.
    49  	if _, ok := formatIP(host); ok {
    50  		return nil, ErrIPAddressNotAllowed
    51  	}
    52  
    53  	d := &srvResolver{
    54  		ctx:  cancel.New(context.Background()),
    55  		rn:   make(chan struct{}, 1),
    56  		host: host,
    57  		port: port,
    58  		cc:   cc,
    59  	}
    60  
    61  	d.wg.Add(1)
    62  	go d.watcher()
    63  	return d, nil
    64  }
    65  
    66  func (*srvResolverBuilder) Scheme() string { return "dns-srv" }
    67  
// srvResolver watches for name resolution updates of a non-IP target by
// periodically looking up its SRV records (see watcher and lookupSRV).
type srvResolver struct {
	host  string              // DNS name whose SRV records are looked up
	port  string              // port parsed from the target by Build; addresses use the SRV record's port
	ctx   cancel.Context      // cancelled by Close() to stop watcher() and in-flight lookups
	cc    resolver.ClientConn // receives address updates and lookup errors
	state resolver.State      // addresses from the most recent lookup
	// rn channel is used by ResolveNow() to force an immediate resolution of the target.
	rn chan struct{}
	// wg makes Close() return only after the watcher() goroutine has finished;
	// otherwise a data race is possible. Race example (from the grpc-go dns
	// resolver this file is based on): dns_resolver_test replaces the real lookup
	// functions with mocks to facilitate testing. If Close() did not wait for the
	// watcher() goroutine to finish, the race detector would sometimes report that
	// the lookup inside watcher() (a READ of the lookup function pointers) races
	// with replaceNetFunc (a WRITE of the lookup function pointers).
	wg sync.WaitGroup
}
    85  
    86  // ResolveNow invoke an immediate resolution of the target that this srvResolver watches.
    87  func (d *srvResolver) ResolveNow(resolver.ResolveNowOptions) {
    88  	select {
    89  	case d.rn <- struct{}{}:
    90  	default:
    91  	}
    92  }
    93  
// Close stops the srvResolver: it cancels the resolver's context and then
// blocks until the watcher() goroutine has exited.
func (d *srvResolver) Close() {
	d.ctx.Cancel() // wakes watcher() through d.ctx.Done()
	d.wg.Wait()    // watcher() calls wg.Done() when it returns
}
    99  
   100  func (d *srvResolver) watcher() {
   101  	defer d.wg.Done()
   102  
   103  	ticker := time.NewTicker(time.Minute * 60)
   104  	backOff := &retry.ExponentialBackOff{
   105  		Min:    time.Second,
   106  		Max:    120 * time.Second,
   107  		Factor: 1.6,
   108  	}
   109  	var lastSuccess time.Time
   110  
   111  	for {
   112  		// Avoid constantly re-resolving if multiple connections make ResolveNow() calls
   113  		if time.Since(lastSuccess) < time.Second*30 {
   114  			goto wait
   115  		}
   116  
   117  		if err := d.lookupSRV(); err != nil {
   118  			d.cc.ReportError(err)
   119  			next := backOff.NextIteration()
   120  			GRPCSrvDefaultLogger.WithError(err).WithField("retry-after", next).
   121  				Error("dns lookup failed; retrying...")
   122  
   123  			timer := time.NewTimer(next)
   124  			select {
   125  			case <-d.ctx.Done():
   126  				timer.Stop()
   127  				return
   128  			case <-timer.C:
   129  			}
   130  			continue
   131  		}
   132  		lastSuccess = time.Now()
   133  	wait:
   134  		backOff.Reset()
   135  
   136  		select {
   137  		case <-d.ctx.Done():
   138  			ticker.Stop()
   139  			return
   140  		case <-ticker.C:
   141  		case <-d.rn:
   142  		}
   143  	}
   144  }
   145  
   146  func (d *srvResolver) lookupSRV() error {
   147  	var result []resolver.Address
   148  
   149  	// TODO(thrawn01): At some point in the future we might parse the Target and determine
   150  	//  if the Target name is in the RFC 2782 form of `_<service>._tcp[.service][.<datacenter>].<domain>`
   151  	//  then fill out the service and proto fields in LookupSRV()
   152  
   153  	_, srvs, err := net.DefaultResolver.LookupSRV(d.ctx, "", "", d.host)
   154  	if err != nil {
   155  		return fmt.Errorf("SRV record lookup err: %w", err)
   156  	}
   157  	for _, s := range srvs {
   158  		resolved, err := net.DefaultResolver.LookupHost(d.ctx, s.Target)
   159  		if err != nil {
   160  			GRPCSrvDefaultLogger.WithError(err).WithField("target", s.Target).
   161  				Error("error resolving 'A' records for SRV entry")
   162  			continue
   163  		}
   164  
   165  		var addresses []resolver.Address
   166  		for _, a := range resolved {
   167  			ip, ok := formatIP(a)
   168  			if !ok {
   169  				GRPCSrvDefaultLogger.WithField("ip", ip).
   170  					Error("error parsing 'A' record for SRV entries; is not a valid ip address")
   171  				continue
   172  			}
   173  			addresses = append(addresses, resolver.Address{Addr: ip + ":" + strconv.Itoa(int(s.Port)), ServerName: s.Target})
   174  		}
   175  
   176  		// If our current state is empty, then immediately update state before looking up the remaining SRV records.
   177  		// Looking up a lot of hosts could take a lot of time, and we want to connect as quickly as possible.
   178  		// During testing, a DNS lookup on all service nodes for `ratelimits` took 5+ seconds, which caused the
   179  		// GRPC calls to timeout.
   180  		if len(d.state.Addresses) == 0 {
   181  			if err := d.cc.UpdateState(resolver.State{Addresses: addresses}); err != nil {
   182  				GRPCSrvDefaultLogger.WithError(err).Error("UpdateState() call returned an error")
   183  			}
   184  			d.state.Addresses = addresses
   185  		}
   186  		result = append(result, addresses...)
   187  	}
   188  
   189  	if len(result) == 0 {
   190  		return fmt.Errorf("SRV record for '%s' contained no valid domain names", d.host)
   191  	}
   192  
   193  	d.state.Addresses = result
   194  	if GRPCSrvLogAddresses {
   195  		var addresses []string
   196  		for _, a := range result {
   197  			addresses = append(addresses, a.Addr)
   198  		}
   199  		GRPCSrvDefaultLogger.WithField("addresses", addresses).Info("dns-srv: address list updated")
   200  	}
   201  	return d.cc.UpdateState(d.state)
   202  }
   203  
   204  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   205  // If addr is an IPv4 address, return the addr and ok = true.
   206  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   207  func formatIP(addr string) (addrIP string, ok bool) {
   208  	ip := net.ParseIP(addr)
   209  	if ip == nil {
   210  		return "", false
   211  	}
   212  	if ip.To4() != nil {
   213  		return addr, true
   214  	}
   215  	return "[" + addr + "]", true
   216  }
   217  
   218  // parseTarget takes the user input target string and default port, returns formatted host and port info.
   219  // If target doesn't specify a port, set the port to be the defaultPort.
   220  // If target is in IPv6 format and host-name is enclosed in square brackets, brackets
   221  // are stripped when setting the host.
   222  // examples:
   223  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
   224  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
   225  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
   226  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
   227  func parseTarget(target, defaultPort string) (host, port string, err error) {
   228  	if target == "" {
   229  		return "", "", ErrMissingAddr
   230  	}
   231  	if ip := net.ParseIP(target); ip != nil {
   232  		// target is an IPv4 or IPv6(without brackets) address
   233  		return target, defaultPort, nil
   234  	}
   235  	if host, port, err = net.SplitHostPort(target); err == nil {
   236  		if port == "" {
   237  			// If the port field is empty (target ends with colon), e.g. "[::1]:", this is an error.
   238  			return "", "", ErrEndsWithColon
   239  		}
   240  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
   241  		if host == "" {
   242  			// Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed.
   243  			host = "localhost"
   244  		}
   245  		return host, port, nil
   246  	}
   247  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
   248  		// target doesn't have port
   249  		return host, port, nil
   250  	}
   251  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
   252  }