github.com/yandex/pandora@v0.5.32/lib/netutil/dial.go (about) 1 package netutil 2 3 import ( 4 "context" 5 "net" 6 "sync" 7 "time" 8 9 "github.com/pkg/errors" 10 ) 11 12 //go:generate mockery --name=Dialer --case=underscore --outpkg=netmock 13 14 type Dialer interface { 15 DialContext(ctx context.Context, net, addr string) (net.Conn, error) 16 } 17 18 var _ Dialer = &net.Dialer{} 19 20 type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error) 21 22 func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 23 return f(ctx, network, address) 24 } 25 26 // NewDNSCachingDialer returns dialer with primitive DNS caching logic 27 // that remembers remote address on first try, and use it in future. 28 func NewDNSCachingDialer(dialer Dialer, cache DNSCache) DialerFunc { 29 return func(ctx context.Context, network, addr string) (conn net.Conn, err error) { 30 resolved, ok := cache.Get(addr) 31 if ok { 32 return dialer.DialContext(ctx, network, resolved) 33 } 34 conn, err = dialer.DialContext(ctx, network, addr) 35 if err != nil { 36 return 37 } 38 remoteAddr := conn.RemoteAddr().(*net.TCPAddr) 39 _, port, err := net.SplitHostPort(addr) 40 if err != nil { 41 _ = conn.Close() 42 return nil, errors.Wrap(err, "invalid address, but successful dial - should not happen") 43 } 44 cache.Add(addr, net.JoinHostPort(remoteAddr.IP.String(), port)) 45 return 46 } 47 } 48 49 var DefaultDNSCache = &SimpleDNSCache{} 50 51 // LookupReachable tries to resolve addr via connecting to it. 52 // This method has much more overhead, but get guaranteed reachable resolved addr. 53 // Example: host is resolved to IPv4 and IPv6, but IPv4 is not working on machine. 54 // LookupReachable will return IPv6 in that case. 55 func LookupReachable(addr string, timeout time.Duration) (string, error) { 56 d := net.Dialer{DualStack: true, Timeout: timeout} 57 conn, err := d.Dial("tcp", addr) 58 if err != nil { 59 return "", err 60 } 61 defer conn.Close() 62 _, port, err := net.SplitHostPort(addr) 63 if err != nil { 64 return "", err 65 } 66 remoteAddr := conn.RemoteAddr().(*net.TCPAddr) 67 return net.JoinHostPort(remoteAddr.IP.String(), port), nil 68 } 69 70 // WarmDNSCache tries connect to addr, and adds conn remote ip + addr port to cache. 71 func WarmDNSCache(c DNSCache, addr string) error { 72 var d net.Dialer 73 conn, err := NewDNSCachingDialer(&d, c).DialContext(context.Background(), "tcp", addr) 74 if err != nil { 75 return err 76 } 77 _ = conn.Close() 78 return nil 79 } 80 81 //go:generate mockery --name=DNSCache --case=underscore --outpkg=netmock 82 83 type DNSCache interface { 84 Get(addr string) (string, bool) 85 Add(addr, resolved string) 86 } 87 88 type SimpleDNSCache struct { 89 rw sync.RWMutex 90 hostToAddr map[string]string 91 } 92 93 func (c *SimpleDNSCache) Get(addr string) (resolved string, ok bool) { 94 c.rw.RLock() 95 if c.hostToAddr == nil { 96 c.rw.RUnlock() 97 return 98 } 99 resolved, ok = c.hostToAddr[addr] 100 c.rw.RUnlock() 101 return 102 } 103 104 func (c *SimpleDNSCache) Add(addr, resolved string) { 105 c.rw.Lock() 106 if c.hostToAddr == nil { 107 c.hostToAddr = make(map[string]string) 108 } 109 c.hostToAddr[addr] = resolved 110 c.rw.Unlock() 111 }