github.com/yandex/pandora@v0.5.32/lib/netutil/dial.go (about)

     1  package netutil
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  //go:generate mockery --name=Dialer --case=underscore --outpkg=netmock
    13  
    14  type Dialer interface {
    15  	DialContext(ctx context.Context, net, addr string) (net.Conn, error)
    16  }
    17  
    18  var _ Dialer = &net.Dialer{}
    19  
    20  type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
    21  
    22  func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
    23  	return f(ctx, network, address)
    24  }
    25  
    26  // NewDNSCachingDialer returns dialer with primitive DNS caching logic
    27  // that remembers remote address on first try, and use it in future.
    28  func NewDNSCachingDialer(dialer Dialer, cache DNSCache) DialerFunc {
    29  	return func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
    30  		resolved, ok := cache.Get(addr)
    31  		if ok {
    32  			return dialer.DialContext(ctx, network, resolved)
    33  		}
    34  		conn, err = dialer.DialContext(ctx, network, addr)
    35  		if err != nil {
    36  			return
    37  		}
    38  		remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
    39  		_, port, err := net.SplitHostPort(addr)
    40  		if err != nil {
    41  			_ = conn.Close()
    42  			return nil, errors.Wrap(err, "invalid address, but successful dial - should not happen")
    43  		}
    44  		cache.Add(addr, net.JoinHostPort(remoteAddr.IP.String(), port))
    45  		return
    46  	}
    47  }
    48  
    49  var DefaultDNSCache = &SimpleDNSCache{}
    50  
    51  // LookupReachable tries to resolve addr via connecting to it.
    52  // This method has much more overhead, but get guaranteed reachable resolved addr.
    53  // Example: host is resolved to IPv4 and IPv6, but IPv4 is not working on machine.
    54  // LookupReachable will return IPv6 in that case.
    55  func LookupReachable(addr string, timeout time.Duration) (string, error) {
    56  	d := net.Dialer{DualStack: true, Timeout: timeout}
    57  	conn, err := d.Dial("tcp", addr)
    58  	if err != nil {
    59  		return "", err
    60  	}
    61  	defer conn.Close()
    62  	_, port, err := net.SplitHostPort(addr)
    63  	if err != nil {
    64  		return "", err
    65  	}
    66  	remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
    67  	return net.JoinHostPort(remoteAddr.IP.String(), port), nil
    68  }
    69  
    70  // WarmDNSCache tries connect to addr, and adds conn remote ip + addr port to cache.
    71  func WarmDNSCache(c DNSCache, addr string) error {
    72  	var d net.Dialer
    73  	conn, err := NewDNSCachingDialer(&d, c).DialContext(context.Background(), "tcp", addr)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	_ = conn.Close()
    78  	return nil
    79  }
    80  
    81  //go:generate mockery --name=DNSCache --case=underscore --outpkg=netmock
    82  
    83  type DNSCache interface {
    84  	Get(addr string) (string, bool)
    85  	Add(addr, resolved string)
    86  }
    87  
    88  type SimpleDNSCache struct {
    89  	rw         sync.RWMutex
    90  	hostToAddr map[string]string
    91  }
    92  
    93  func (c *SimpleDNSCache) Get(addr string) (resolved string, ok bool) {
    94  	c.rw.RLock()
    95  	if c.hostToAddr == nil {
    96  		c.rw.RUnlock()
    97  		return
    98  	}
    99  	resolved, ok = c.hostToAddr[addr]
   100  	c.rw.RUnlock()
   101  	return
   102  }
   103  
   104  func (c *SimpleDNSCache) Add(addr, resolved string) {
   105  	c.rw.Lock()
   106  	if c.hostToAddr == nil {
   107  		c.hostToAddr = make(map[string]string)
   108  	}
   109  	c.hostToAddr[addr] = resolved
   110  	c.rw.Unlock()
   111  }