/*
 *Copyright (c) 2022, kaydxh
 *
 *Permission is hereby granted, free of charge, to any person obtaining a copy
 *of this software and associated documentation files (the "Software"), to deal
 *in the Software without restriction, including without limitation the rights
 *to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 *copies of the Software, and to permit persons to whom the Software is
 *furnished to do so, subject to the following conditions:
 *
 *The above copyright notice and this permission notice shall be included in all
 *copies or substantial portions of the Software.
 *
 *THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 *IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 *FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 *AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 *LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 *OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 *SOFTWARE.
 */
package dns

import (
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"time"

	rand_ "github.com/kaydxh/golang/go/math/rand"
	net_ "github.com/kaydxh/golang/go/net"
	"github.com/kaydxh/golang/go/net/resolver"
	time_ "github.com/kaydxh/golang/go/time"
)

// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
// addresses from SRV records. Must not be changed after init time.
var EnableSRVLookups = false

// Globals to stub out in tests. TODO: Perhaps these two can be combined into a
// single variable for testing the resolver?
44 var ( 45 newTimer = time.NewTimer 46 ) 47 48 func init() { 49 resolver.Register(NewBuilder()) 50 } 51 52 const ( 53 defaultPort = "443" 54 defaultDNSSvrPort = "53" 55 ) 56 57 var ( 58 defaultResolver netResolver = net.DefaultResolver 59 // To prevent excessive re-resolution, we enforce a rate limit on DNS 60 // resolution requests. 61 defaultSyncInterval = 30 * time.Second 62 ) 63 64 var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { 65 return func(ctx context.Context, network, address string) (net.Conn, error) { 66 var dialer net.Dialer 67 return dialer.DialContext(ctx, network, authority) 68 } 69 } 70 71 var customAuthorityResolver = func(authority string) (netResolver, error) { 72 host, port, err := net_.ParseTarget(authority, defaultDNSSvrPort) 73 if err != nil { 74 return nil, err 75 } 76 77 authorityWithPort := net.JoinHostPort(host, port) 78 79 return &net.Resolver{ 80 PreferGo: true, 81 Dial: customAuthorityDialler(authorityWithPort), 82 }, nil 83 } 84 85 // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers. 86 func NewBuilder(opts ...dnsBuilderOption) resolver.Builder { 87 b := &dnsBuilder{} 88 b.ApplyOptions(opts...) 89 if b.opts.syncInterval == 0 { 90 b.opts.syncInterval = defaultSyncInterval 91 } 92 93 return b 94 } 95 96 type dnsBuilder struct { 97 opts struct { 98 syncInterval time.Duration 99 } 100 } 101 102 // Build creates and starts a DNS resolver that watches the name resolution of the target. 103 func (b *dnsBuilder) Build(target resolver.Target, opts ...resolver.ResolverBuildOption) (resolver.Resolver, error) { 104 var opt resolver.ResolverBuildOptions 105 opt.ApplyOptions(opts...) 106 host, port, err := net_.ParseTarget(target.Endpoint, defaultPort) 107 if err != nil { 108 return nil, err 109 } 110 cc := opt.Cc 111 112 // IP address. 
113 if ipAddr, ok := formatIP(host); ok { 114 addr := []resolver.Address{{Addr: ipAddr + ":" + port}} 115 if cc != nil { 116 cc.UpdateState(resolver.State{Addresses: addr}) 117 } 118 return deadResolver{ 119 addrs: addr, 120 }, nil 121 } 122 123 // DNS address (non-IP). 124 ctx, cancel := context.WithCancel(context.Background()) 125 d := &dnsResolver{ 126 host: host, 127 port: port, 128 syncInterval: b.opts.syncInterval, 129 ctx: ctx, 130 cancel: cancel, 131 cc: cc, 132 rn: make(chan struct{}, 1), 133 } 134 135 if target.Authority == "" { 136 d.resolver = defaultResolver 137 } else { 138 d.resolver, err = customAuthorityResolver(target.Authority) 139 if err != nil { 140 return nil, err 141 } 142 } 143 144 d.wg.Add(1) 145 go d.watcher() 146 return d, nil 147 } 148 149 // Scheme returns the naming scheme of this resolver builder, which is "dns". 150 func (b *dnsBuilder) Scheme() string { 151 return "dns" 152 } 153 154 type netResolver interface { 155 LookupHost(ctx context.Context, host string) (addrs []string, err error) 156 LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) 157 LookupTXT(ctx context.Context, name string) (txts []string, err error) 158 } 159 160 // deadResolver is a resolver that does nothing. 161 type deadResolver struct { 162 addrs []resolver.Address 163 } 164 165 func (d deadResolver) ResolveOne(opts ...resolver.ResolveOneOption) (resolver.Address, error) { 166 var opt resolver.ResolveOneOptions 167 opt.ApplyOptions(opts...) 
168 169 addrs, err := d.ResolveAll(resolver.WithIPTypeForResolveAll(opt.IPType)) 170 if err != nil { 171 return resolver.Address{}, err 172 } 173 174 switch opt.PickMode { 175 case resolver.Resolver_pick_mode_random: 176 return addrs[rand_.Intn(len(addrs))], nil 177 case resolver.Resolver_pick_mode_first: 178 return addrs[0], nil 179 default: 180 return addrs[rand_.Intn(len(addrs))], nil 181 182 } 183 } 184 185 func (d deadResolver) ResolveAll(opts ...resolver.ResolveAllOption) ([]resolver.Address, error) { 186 var opt resolver.ResolveAllOptions 187 opt.ApplyOptions(opts...) 188 if len(d.addrs) == 0 { 189 return nil, fmt.Errorf("resolve target's addresses are empty") 190 } 191 192 var pickAddrs []resolver.Address 193 if opt.IPType == resolver.Resolver_ip_type_all { 194 pickAddrs = d.addrs 195 } else { 196 for _, addr := range d.addrs { 197 v4 := (opt.IPType == resolver.Resolver_ip_type_v4) 198 ip, _, _ := net_.SplitHostIntPort(addr.Addr) 199 if net_.IsIPv4String(ip) { 200 if v4 { 201 pickAddrs = append(pickAddrs, addr) 202 } 203 } else { 204 //v6 205 if !v4 { 206 pickAddrs = append(pickAddrs, addr) 207 } 208 } 209 } 210 } 211 if len(pickAddrs) == 0 { 212 return nil, fmt.Errorf("resolve target's addresses type[%v] are empty", opt.IPType) 213 } 214 return pickAddrs, nil 215 } 216 217 func (deadResolver) ResolveNow(opts ...resolver.ResolveNowOption) {} 218 219 func (deadResolver) Close() {} 220 221 // dnsResolver watches for the name resolution update for a non-IP target. 222 type dnsResolver struct { 223 host string 224 port string 225 resolver netResolver 226 syncInterval time.Duration 227 228 ctx context.Context 229 cancel context.CancelFunc 230 cc resolver.ClientConn 231 // rn channel is used by ResolveNow() to force an immediate resolution of the target. 232 rn chan struct{} 233 // wg is used to enforce Close() to return after the watcher() goroutine has finished. 234 // Otherwise, data race will be possible. 
[Race Example] in dns_resolver_test we 235 // replace the real lookup functions with mocked ones to facilitate testing. 236 // If Close() doesn't wait for watcher() goroutine finishes, race detector sometimes 237 // will warns lookup (READ the lookup function pointers) inside watcher() goroutine 238 // has data race with replaceNetFunc (WRITE the lookup function pointers). 239 wg sync.WaitGroup 240 } 241 242 func (d *dnsResolver) ResolveOne(opts ...resolver.ResolveOneOption) (resolver.Address, error) { 243 var opt resolver.ResolveOneOptions 244 opt.ApplyOptions(opts...) 245 246 addrs, err := d.ResolveAll(resolver.WithIPTypeForResolveAll(opt.IPType)) 247 if err != nil { 248 return resolver.Address{}, err 249 } 250 251 switch opt.PickMode { 252 case resolver.Resolver_pick_mode_random: 253 return addrs[rand_.Intn(len(addrs))], nil 254 case resolver.Resolver_pick_mode_first: 255 return addrs[0], nil 256 default: 257 return addrs[rand_.Intn(len(addrs))], nil 258 259 } 260 } 261 262 func (d *dnsResolver) ResolveAll(opts ...resolver.ResolveAllOption) ([]resolver.Address, error) { 263 var opt resolver.ResolveAllOptions 264 opt.ApplyOptions(opts...) 
265 d.ResolveNow() 266 addrs, err := d.lookupHost() 267 if err != nil { 268 return nil, err 269 } 270 if len(addrs) == 0 { 271 return nil, fmt.Errorf("resolve target's addresses are empty") 272 } 273 274 var pickAddrs []resolver.Address 275 if opt.IPType == resolver.Resolver_ip_type_all { 276 pickAddrs = addrs 277 } else { 278 for _, addr := range addrs { 279 v4 := (opt.IPType == resolver.Resolver_ip_type_v4) 280 ip, _, _ := net_.SplitHostIntPort(addr.Addr) 281 if net_.IsIPv4String(ip) { 282 if v4 { 283 pickAddrs = append(pickAddrs, addr) 284 } 285 } else { 286 //v6 287 if !v4 { 288 pickAddrs = append(pickAddrs, addr) 289 } 290 } 291 } 292 } 293 if len(pickAddrs) == 0 { 294 return nil, fmt.Errorf("resolve target's addresses type[%v] are empty", opt.IPType) 295 } 296 return pickAddrs, nil 297 } 298 299 // ResolveNow invoke an immediate resolution of the target that this dnsResolver watches. 300 func (d *dnsResolver) ResolveNow(opts ...resolver.ResolveNowOption) { 301 select { 302 case d.rn <- struct{}{}: 303 default: 304 } 305 } 306 307 // Close closes the dnsResolver. 308 func (d *dnsResolver) Close() { 309 d.cancel() 310 d.wg.Wait() 311 } 312 313 func (d *dnsResolver) watcher() { 314 defer d.wg.Done() 315 316 backoff := time_.NewExponentialBackOff() 317 for { 318 addrs, err := d.lookupHost() 319 if d.cc != nil { 320 if err != nil { 321 // Report error to the underlying grpc.ClientConn. 322 d.cc.ReportError(err) 323 } else { 324 err = d.cc.UpdateState(resolver.State{Addresses: addrs}) 325 } 326 } 327 328 var timer *time.Timer 329 if err == nil { 330 // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least 331 // to prevent constantly re-resolving. 332 backoff.Reset() 333 timer = newTimer(d.syncInterval) 334 select { 335 case <-d.ctx.Done(): 336 timer.Stop() 337 return 338 case <-d.rn: 339 } 340 } else { 341 // Poll on an error found in DNS Resolver or an error received from ClientConn. 
342 actualInterval, _ := backoff.NextBackOff() 343 timer = newTimer(actualInterval) 344 } 345 select { 346 case <-d.ctx.Done(): 347 timer.Stop() 348 return 349 case <-timer.C: 350 } 351 } 352 } 353 354 func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { 355 if !EnableSRVLookups { 356 return nil, nil 357 } 358 var newAddrs []resolver.Address 359 _, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host) 360 if err != nil { 361 err = handleDNSError(err, "SRV") // may become nil 362 return nil, err 363 } 364 for _, s := range srvs { 365 lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target) 366 if err != nil { 367 err = handleDNSError(err, "A") // may become nil 368 if err == nil { 369 // If there are other SRV records, look them up and ignore this 370 // one that does not exist. 371 continue 372 } 373 return nil, err 374 } 375 for _, a := range lbAddrs { 376 ip, ok := formatIP(a) 377 if !ok { 378 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a) 379 } 380 addr := ip + ":" + strconv.Itoa(int(s.Port)) 381 newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target}) 382 } 383 } 384 return newAddrs, nil 385 } 386 387 func handleDNSError(err error, lookupType string) error { 388 if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary { 389 // Timeouts and temporary errors should be communicated to gRPC to 390 // attempt another DNS query (with backoff). Other errors should be 391 // suppressed (they may represent the absence of a TXT record). 
392 return nil 393 } 394 if err != nil { 395 err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err) 396 } 397 return err 398 } 399 400 func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { 401 addrs, err := d.resolver.LookupHost(d.ctx, d.host) 402 if err != nil { 403 err = handleDNSError(err, "A") 404 return nil, err 405 } 406 newAddrs := make([]resolver.Address, 0, len(addrs)) 407 for _, a := range addrs { 408 ip, ok := formatIP(a) 409 if !ok { 410 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a) 411 } 412 addr := ip + ":" + d.port 413 newAddrs = append(newAddrs, resolver.Address{Addr: addr}) 414 } 415 return newAddrs, nil 416 } 417 418 // formatIP returns ok = false if addr is not a valid textual representation of an IP address. 419 // If addr is an IPv4 address, return the addr and ok = true. 420 // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. 421 func formatIP(addr string) (addrIP string, ok bool) { 422 ip := net.ParseIP(addr) 423 if ip == nil { 424 return "", false 425 } 426 if ip.To4() != nil { 427 return addr, true 428 } 429 return "[" + addr + "]", true 430 }