github.com/searKing/golang/go@v1.2.117/net/resolver/dns/dns_resolver.go (about) 1 // Copyright 2021 The searKing Author. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package dns 6 7 import ( 8 "context" 9 "errors" 10 "fmt" 11 "net" 12 "strconv" 13 "sync" 14 "time" 15 16 rand_ "github.com/searKing/golang/go/math/rand" 17 "github.com/searKing/golang/go/net/resolver" 18 time_ "github.com/searKing/golang/go/time" 19 ) 20 21 // EnableSRVLookups controls whether the DNS resolver attempts to fetch 22 // addresses from SRV records. Must not be changed after init time. 23 var EnableSRVLookups = false 24 25 // Globals to stub out in tests. 26 var newTimer = time.NewTimer 27 28 func init() { 29 resolver.Register(NewBuilder()) 30 } 31 32 const ( 33 defaultPort = "443" 34 defaultDNSSvrPort = "53" 35 ) 36 37 var ( 38 errMissingAddr = errors.New("dns resolver: missing address") 39 40 // Addresses ending with a colon that is supposed to be the separator 41 // between host and port is not allowed. E.g. "::" is a valid address as 42 // it is an IPv6 address (host only) and "[::]:" is invalid as it ends with 43 // a colon as the host and port separator 44 errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon") 45 ) 46 47 var ( 48 defaultResolver netResolver = net.DefaultResolver 49 // To prevent excessive re-resolution, we enforce a rate limit on DNS 50 // resolution requests. 51 minDNSResRate = 30 * time.Second 52 ) 53 54 var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { 55 return func(ctx context.Context, network, address string) (net.Conn, error) { 56 var dialer net.Dialer 57 return dialer.DialContext(ctx, network, authority) 58 } 59 } 60 61 var customAuthorityResolver = func(authority string) (netResolver, error) { 62 host, port, err := parseTarget(authority, defaultDNSSvrPort) 63 if err != nil { 64 return nil, err 65 } 66 67 authorityWithPort := net.JoinHostPort(host, port) 68 69 return &net.Resolver{ 70 PreferGo: true, 71 Dial: customAuthorityDialler(authorityWithPort), 72 }, nil 73 } 74 75 // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers. 76 func NewBuilder() resolver.Builder { 77 return &dnsBuilder{} 78 } 79 80 type dnsBuilder struct{} 81 82 // Build creates and starts a DNS resolver that watches the name resolution of the target. 83 func (b *dnsBuilder) Build(ctx context.Context, target resolver.Target, opts ...resolver.BuildOption) (resolver.Resolver, error) { 84 var opt resolver.Build 85 opt.ApplyOptions(opts...) 86 host, port, err := parseTarget(target.Endpoint, defaultPort) 87 if err != nil { 88 return nil, err 89 } 90 cc := opt.ClientConn 91 92 // IP address. 93 if ipAddr, ok := formatIP(host); ok { 94 addr := []resolver.Address{{Addr: ipAddr + ":" + port}} 95 if cc != nil { 96 _ = cc.UpdateState(resolver.State{Addresses: addr}) 97 } 98 return deadResolver{ 99 addrs: addr, 100 }, nil 101 } 102 103 // DNS address (non-IP). 104 ctx, cancel := context.WithCancel(context.Background()) 105 d := &dnsResolver{ 106 host: host, 107 port: port, 108 ctx: ctx, 109 cancel: cancel, 110 cc: cc, 111 rn: make(chan struct{}, 1), 112 } 113 114 if target.Authority == "" { 115 d.resolver = defaultResolver 116 } else { 117 d.resolver, err = customAuthorityResolver(target.Authority) 118 if err != nil { 119 return nil, err 120 } 121 } 122 123 d.wg.Add(1) 124 go d.watcher() 125 return d, nil 126 } 127 128 // Scheme returns the naming scheme of this resolver builder, which is "dns". 129 func (b *dnsBuilder) Scheme() string { 130 return "dns" 131 } 132 133 type netResolver interface { 134 LookupHost(ctx context.Context, host string) (addrs []string, err error) 135 LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) 136 LookupTXT(ctx context.Context, name string) (txts []string, err error) 137 } 138 139 // deadResolver is a resolver that does nothing. 140 type deadResolver struct { 141 picker resolver.Picker 142 addrs []resolver.Address 143 } 144 145 func (d deadResolver) ResolveOneAddr(ctx context.Context, opts ...resolver.ResolveOneAddrOption) (resolver.Address, error) { 146 if len(d.addrs) == 0 { 147 return resolver.Address{}, fmt.Errorf("resolve target, but no addr") 148 } 149 return d.addrs[rand_.Intn(len(d.addrs))], nil 150 } 151 func (d deadResolver) ResolveAddr(ctx context.Context, opts ...resolver.ResolveAddrOption) ([]resolver.Address, error) { 152 return d.addrs, nil 153 } 154 func (deadResolver) ResolveNow(ctx context.Context, opts ...resolver.ResolveNowOption) {} 155 156 func (deadResolver) Close() {} 157 158 // dnsResolver watches for the name resolution update for a non-IP target. 159 type dnsResolver struct { 160 host string 161 port string 162 resolver netResolver 163 ctx context.Context 164 cancel context.CancelFunc 165 cc resolver.ClientConn 166 // rn channel is used by ResolveNow() to force an immediate resolution of the target. 167 rn chan struct{} 168 // wg is used to enforce Close() to return after the watcher() goroutine has finished. 169 // Otherwise, data race will be possible. [Race Example] in dns_resolver_test we 170 // replace the real lookup functions with mocked ones to facilitate testing. 171 // If Close() doesn't wait for watcher() goroutine finishes, race detector sometimes 172 // will warns lookup (READ the lookup function pointers) inside watcher() goroutine 173 // has data race with replaceNetFunc (WRITE the lookup function pointers). 174 wg sync.WaitGroup 175 } 176 177 func (d *dnsResolver) ResolveOneAddr(ctx context.Context, opts ...resolver.ResolveOneAddrOption) (resolver.Address, error) { 178 d.ResolveNow(ctx) 179 addrs, err := d.lookupHost() 180 if err != nil { 181 return resolver.Address{}, err 182 } 183 if len(addrs) == 0 { 184 return resolver.Address{}, fmt.Errorf("resolve target, but no addr") 185 } 186 return addrs[rand_.Intn(len(addrs))], nil 187 } 188 189 func (d *dnsResolver) ResolveAddr(ctx context.Context, opts ...resolver.ResolveAddrOption) ([]resolver.Address, error) { 190 d.ResolveNow(ctx) 191 return d.lookupHost() 192 } 193 194 // ResolveNow invoke an immediate resolution of the target that this dnsResolver watches. 195 func (d *dnsResolver) ResolveNow(ctx context.Context, opts ...resolver.ResolveNowOption) { 196 select { 197 case d.rn <- struct{}{}: 198 default: 199 } 200 } 201 202 // Close closes the dnsResolver. 203 func (d *dnsResolver) Close() { 204 d.cancel() 205 d.wg.Wait() 206 } 207 208 func (d *dnsResolver) watcher() { 209 defer d.wg.Done() 210 211 backoff := time_.NewGrpcExponentialBackOff() 212 for { 213 addrs, err := d.lookupHost() 214 if d.cc != nil { 215 if err != nil { 216 // Report error to the underlying grpc.ClientConn. 217 d.cc.ReportError(err) 218 } else { 219 err = d.cc.UpdateState(resolver.State{Addresses: addrs}) 220 } 221 } 222 223 var timer *time.Timer 224 if err == nil { 225 // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least 226 // to prevent constantly re-resolving. 227 backoff.Reset() 228 timer = newTimer(minDNSResRate) 229 select { 230 case <-d.ctx.Done(): 231 timer.Stop() 232 return 233 case <-d.rn: 234 } 235 } else { 236 // Poll on an error found in DNS Resolver or an error received from ClientConn. 237 bc, _ := backoff.NextBackOff() 238 timer = newTimer(bc) 239 } 240 select { 241 case <-d.ctx.Done(): 242 timer.Stop() 243 return 244 case <-timer.C: 245 } 246 } 247 } 248 249 func (d *dnsResolver) lookupSRV(service, proto string) ([]string, error) { 250 if !EnableSRVLookups { 251 return nil, nil 252 } 253 var newAddrs []string 254 _, srvs, err := d.resolver.LookupSRV(d.ctx, service, proto, d.host) 255 if err != nil { 256 err = handleDNSError(err, "SRV") // may become nil 257 return nil, err 258 } 259 for _, s := range srvs { 260 lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target) 261 if err != nil { 262 err = handleDNSError(err, "A") // may become nil 263 if err == nil { 264 // If there are other SRV records, look them up and ignore this 265 // one that does not exist. 266 continue 267 } 268 return nil, err 269 } 270 for _, a := range lbAddrs { 271 ip, ok := formatIP(a) 272 if !ok { 273 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a) 274 } 275 addr := ip + ":" + strconv.Itoa(int(s.Port)) 276 newAddrs = append(newAddrs, addr) 277 } 278 } 279 return newAddrs, nil 280 } 281 282 var filterError = func(err error) error { 283 if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary { 284 // Timeouts and temporary errors should be communicated to gRPC to 285 // attempt another DNS query (with backoff). Other errors should be 286 // suppressed (they may represent the absence of a TXT record). 287 return nil 288 } 289 return err 290 } 291 292 func handleDNSError(err error, lookupType string) error { 293 err = filterError(err) 294 if err != nil { 295 err = fmt.Errorf("dns: %v record lookup error: %w", lookupType, err) 296 return err 297 } 298 return nil 299 } 300 301 func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { 302 var newAddrs []resolver.Address 303 addrs, err := d.resolver.LookupHost(d.ctx, d.host) 304 if err != nil { 305 err = handleDNSError(err, "A") 306 return nil, err 307 } 308 for _, a := range addrs { 309 ip, ok := formatIP(a) 310 if !ok { 311 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a) 312 } 313 addr := ip + ":" + d.port 314 newAddrs = append(newAddrs, resolver.Address{Addr: addr}) 315 } 316 return newAddrs, nil 317 } 318 319 // formatIP returns ok = false if addr is not a valid textual representation of an IP address. 320 // If addr is an IPv4 address, return the addr and ok = true. 321 // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. 322 func formatIP(addr string) (addrIP string, ok bool) { 323 ip := net.ParseIP(addr) 324 if ip == nil { 325 return "", false 326 } 327 if ip.To4() != nil { 328 return addr, true 329 } 330 return "[" + addr + "]", true 331 } 332 333 // parseTarget takes the user input target string and default port, returns formatted host and port info. 334 // If target doesn't specify a port, set the port to be the defaultPort. 335 // If target is in IPv6 format and host-name is enclosed in square brackets, brackets 336 // are stripped when setting the host. 337 // examples: 338 // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443" 339 // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80" 340 // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443" 341 // target: ":80" defaultPort: "443" returns host: "localhost", port: "80" 342 func parseTarget(target, defaultPort string) (host, port string, err error) { 343 if target == "" { 344 return "", "", errMissingAddr 345 } 346 if ip := net.ParseIP(target); ip != nil { 347 // target is an IPv4 or IPv6(without brackets) address 348 return target, defaultPort, nil 349 } 350 if host, port, err = net.SplitHostPort(target); err == nil { 351 if port == "" { 352 // If the port field is empty (target ends with colon), e.g. "[::1]:", this is an error. 353 return "", "", errEndsWithColon 354 } 355 // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port 356 if host == "" { 357 // Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed. 358 host = "localhost" 359 } 360 return host, port, nil 361 } 362 if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil { 363 // target doesn't have port 364 return host, port, nil 365 } 366 return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err) 367 }