github.com/mailgun/holster/v4@v4.20.0/discovery/grpc_srv_resolver.go (about) 1 package discovery 2 3 // Based on grpc-go/internal/resolver/dns 4 5 import ( 6 "context" 7 "errors" 8 "fmt" 9 "net" 10 "strconv" 11 "sync" 12 "time" 13 14 "github.com/mailgun/holster/v4/cancel" 15 "github.com/mailgun/holster/v4/retry" 16 "github.com/sirupsen/logrus" 17 "google.golang.org/grpc/resolver" 18 ) 19 20 func init() { 21 GRPCSrvDefaultLogger = logrus.StandardLogger() 22 } 23 24 var ( 25 ErrMissingAddr = errors.New("missing address") 26 ErrEndsWithColon = errors.New("missing port after port-separator colon") 27 ErrIPAddressNotAllowed = errors.New("ip address is not allowed; must be a dns service name") 28 GRPCSrvDefaultPort = "443" 29 GRPCSrvDefaultLogger logrus.FieldLogger 30 31 // GRPCSrvLogAddresses if true then GRPC will log the list of addresses received when making an SRV 32 GRPCSrvLogAddresses = false 33 ) 34 35 // NewGRPCSRVBuilder creates a srvResolverBuilder which is used to factory SRV-DNS resolvers. 36 func NewGRPCSRVBuilder() resolver.Builder { 37 return &srvResolverBuilder{} 38 } 39 40 type srvResolverBuilder struct{} 41 42 func (*srvResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { 43 host, port, err := parseTarget(target.Endpoint(), GRPCSrvDefaultPort) 44 if err != nil { 45 return nil, err 46 } 47 48 // IP address. 49 if _, ok := formatIP(host); ok { 50 return nil, ErrIPAddressNotAllowed 51 } 52 53 d := &srvResolver{ 54 ctx: cancel.New(context.Background()), 55 rn: make(chan struct{}, 1), 56 host: host, 57 port: port, 58 cc: cc, 59 } 60 61 d.wg.Add(1) 62 go d.watcher() 63 return d, nil 64 } 65 66 func (*srvResolverBuilder) Scheme() string { return "dns-srv" } 67 68 // srvResolver watches for the name resolution update for a non-IP target. 69 type srvResolver struct { 70 host string 71 port string 72 ctx cancel.Context 73 cc resolver.ClientConn 74 state resolver.State 75 // rn channel is used by ResolveNow() to force an immediate resolution of the target. 76 rn chan struct{} 77 // wg is used to enforce Close() to return after the watcher() goroutine has finished. 78 // Otherwise, data race will be possible. [Race Example] in dns_resolver_test we 79 // replace the real lookup functions with mocked ones to facilitate testing. 80 // If Close() doesn't wait for watcher() goroutine finishes, race detector sometimes 81 // will warns lookup (READ the lookup function pointers) inside watcher() goroutine 82 // has data race with replaceNetFunc (WRITE the lookup function pointers). 83 wg sync.WaitGroup 84 } 85 86 // ResolveNow invoke an immediate resolution of the target that this srvResolver watches. 87 func (d *srvResolver) ResolveNow(resolver.ResolveNowOptions) { 88 select { 89 case d.rn <- struct{}{}: 90 default: 91 } 92 } 93 94 // Close closes the srvResolver. 95 func (d *srvResolver) Close() { 96 d.ctx.Cancel() 97 d.wg.Wait() 98 } 99 100 func (d *srvResolver) watcher() { 101 defer d.wg.Done() 102 103 ticker := time.NewTicker(time.Minute * 60) 104 backOff := &retry.ExponentialBackOff{ 105 Min: time.Second, 106 Max: 120 * time.Second, 107 Factor: 1.6, 108 } 109 var lastSuccess time.Time 110 111 for { 112 // Avoid constantly re-resolving if multiple connections make ResolveNow() calls 113 if time.Since(lastSuccess) < time.Second*30 { 114 goto wait 115 } 116 117 if err := d.lookupSRV(); err != nil { 118 d.cc.ReportError(err) 119 next := backOff.NextIteration() 120 GRPCSrvDefaultLogger.WithError(err).WithField("retry-after", next). 121 Error("dns lookup failed; retrying...") 122 123 timer := time.NewTimer(next) 124 select { 125 case <-d.ctx.Done(): 126 timer.Stop() 127 return 128 case <-timer.C: 129 } 130 continue 131 } 132 lastSuccess = time.Now() 133 wait: 134 backOff.Reset() 135 136 select { 137 case <-d.ctx.Done(): 138 ticker.Stop() 139 return 140 case <-ticker.C: 141 case <-d.rn: 142 } 143 } 144 } 145 146 func (d *srvResolver) lookupSRV() error { 147 var result []resolver.Address 148 149 // TODO(thrawn01): At some point in the future we might parse the Target and determine 150 // if the Target name is in the RFC 2782 form of `_<service>._tcp[.service][.<datacenter>].<domain>` 151 // then fill out the service and proto fields in LookupSRV() 152 153 _, srvs, err := net.DefaultResolver.LookupSRV(d.ctx, "", "", d.host) 154 if err != nil { 155 return fmt.Errorf("SRV record lookup err: %w", err) 156 } 157 for _, s := range srvs { 158 resolved, err := net.DefaultResolver.LookupHost(d.ctx, s.Target) 159 if err != nil { 160 GRPCSrvDefaultLogger.WithError(err).WithField("target", s.Target). 161 Error("error resolving 'A' records for SRV entry") 162 continue 163 } 164 165 var addresses []resolver.Address 166 for _, a := range resolved { 167 ip, ok := formatIP(a) 168 if !ok { 169 GRPCSrvDefaultLogger.WithField("ip", ip). 170 Error("error parsing 'A' record for SRV entries; is not a valid ip address") 171 continue 172 } 173 addresses = append(addresses, resolver.Address{Addr: ip + ":" + strconv.Itoa(int(s.Port)), ServerName: s.Target}) 174 } 175 176 // If our current state is empty, then immediately update state before looking up the remaining SRV records. 177 // Looking up a lot of hosts could take a lot of time, and we want to connect as quickly as possible. 178 // During testing, a DNS lookup on all service nodes for `ratelimits` took 5+ seconds, which caused the 179 // GRPC calls to timeout. 180 if len(d.state.Addresses) == 0 { 181 if err := d.cc.UpdateState(resolver.State{Addresses: addresses}); err != nil { 182 GRPCSrvDefaultLogger.WithError(err).Error("UpdateState() call returned an error") 183 } 184 d.state.Addresses = addresses 185 } 186 result = append(result, addresses...) 187 } 188 189 if len(result) == 0 { 190 return fmt.Errorf("SRV record for '%s' contained no valid domain names", d.host) 191 } 192 193 d.state.Addresses = result 194 if GRPCSrvLogAddresses { 195 var addresses []string 196 for _, a := range result { 197 addresses = append(addresses, a.Addr) 198 } 199 GRPCSrvDefaultLogger.WithField("addresses", addresses).Info("dns-srv: address list updated") 200 } 201 return d.cc.UpdateState(d.state) 202 } 203 204 // formatIP returns ok = false if addr is not a valid textual representation of an IP address. 205 // If addr is an IPv4 address, return the addr and ok = true. 206 // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. 207 func formatIP(addr string) (addrIP string, ok bool) { 208 ip := net.ParseIP(addr) 209 if ip == nil { 210 return "", false 211 } 212 if ip.To4() != nil { 213 return addr, true 214 } 215 return "[" + addr + "]", true 216 } 217 218 // parseTarget takes the user input target string and default port, returns formatted host and port info. 219 // If target doesn't specify a port, set the port to be the defaultPort. 220 // If target is in IPv6 format and host-name is enclosed in square brackets, brackets 221 // are stripped when setting the host. 222 // examples: 223 // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443" 224 // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80" 225 // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443" 226 // target: ":80" defaultPort: "443" returns host: "localhost", port: "80" 227 func parseTarget(target, defaultPort string) (host, port string, err error) { 228 if target == "" { 229 return "", "", ErrMissingAddr 230 } 231 if ip := net.ParseIP(target); ip != nil { 232 // target is an IPv4 or IPv6(without brackets) address 233 return target, defaultPort, nil 234 } 235 if host, port, err = net.SplitHostPort(target); err == nil { 236 if port == "" { 237 // If the port field is empty (target ends with colon), e.g. "[::1]:", this is an error. 238 return "", "", ErrEndsWithColon 239 } 240 // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port 241 if host == "" { 242 // Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed. 243 host = "localhost" 244 } 245 return host, port, nil 246 } 247 if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil { 248 // target doesn't have port 249 return host, port, nil 250 } 251 return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err) 252 }