github.com/letsencrypt/boulder@v0.20251208.0/grpc/internal/resolver/dns/dns_resolver.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // Forked from the default internal DNS resolver in the grpc-go package. The 20 // original source can be found at: 21 // https://github.com/grpc/grpc-go/blob/v1.49.0/internal/resolver/dns/dns_resolver.go 22 23 package dns 24 25 import ( 26 "context" 27 "errors" 28 "fmt" 29 "net" 30 "net/netip" 31 "strconv" 32 "strings" 33 "sync" 34 "time" 35 36 "google.golang.org/grpc/grpclog" 37 "google.golang.org/grpc/resolver" 38 "google.golang.org/grpc/serviceconfig" 39 40 "github.com/letsencrypt/boulder/bdns" 41 "github.com/letsencrypt/boulder/grpc/internal/backoff" 42 "github.com/letsencrypt/boulder/grpc/noncebalancer" 43 ) 44 45 var logger = grpclog.Component("srv") 46 47 // Globals to stub out in tests. TODO: Perhaps these two can be combined into a 48 // single variable for testing the resolver? 49 var ( 50 newTimer = time.NewTimer 51 newTimerDNSResRate = time.NewTimer 52 ) 53 54 func init() { 55 resolver.Register(NewDefaultSRVBuilder()) 56 resolver.Register(NewNonceSRVBuilder()) 57 } 58 59 const defaultDNSSvrPort = "53" 60 61 var defaultResolver netResolver = net.DefaultResolver 62 63 var ( 64 // To prevent excessive re-resolution, we enforce a rate limit on DNS 65 // resolution requests. 66 minDNSResRate = 30 * time.Second 67 ) 68 69 var customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { 70 return func(ctx context.Context, network, address string) (net.Conn, error) { 71 var dialer net.Dialer 72 return dialer.DialContext(ctx, network, authority) 73 } 74 } 75 76 var customAuthorityResolver = func(authority string) (*net.Resolver, error) { 77 host, port, err := bdns.ParseTarget(authority, defaultDNSSvrPort) 78 if err != nil { 79 return nil, err 80 } 81 return &net.Resolver{ 82 PreferGo: true, 83 Dial: customAuthorityDialer(net.JoinHostPort(host, port)), 84 }, nil 85 } 86 87 // NewDefaultSRVBuilder creates a srvBuilder which is used to factory SRV DNS 88 // resolvers. 89 func NewDefaultSRVBuilder() resolver.Builder { 90 return &srvBuilder{scheme: "srv"} 91 } 92 93 // NewNonceSRVBuilder creates a srvBuilder which is used to factory SRV DNS 94 // resolvers with a custom grpc.Balancer used by nonce-service clients. 95 func NewNonceSRVBuilder() resolver.Builder { 96 return &srvBuilder{scheme: noncebalancer.SRVResolverScheme, balancer: noncebalancer.Name} 97 } 98 99 type srvBuilder struct { 100 scheme string 101 balancer string 102 } 103 104 // Build creates and starts a DNS resolver that watches the name resolution of the target. 105 func (b *srvBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { 106 var names []name 107 for i := range strings.SplitSeq(target.Endpoint(), ",") { 108 service, domain, err := parseServiceDomain(i) 109 if err != nil { 110 return nil, err 111 } 112 names = append(names, name{service: service, domain: domain}) 113 } 114 115 ctx, cancel := context.WithCancel(context.Background()) 116 d := &dnsResolver{ 117 names: names, 118 ctx: ctx, 119 cancel: cancel, 120 cc: cc, 121 rn: make(chan struct{}, 1), 122 } 123 124 if target.URL.Host == "" { 125 d.resolver = defaultResolver 126 } else { 127 var err error 128 d.resolver, err = customAuthorityResolver(target.URL.Host) 129 if err != nil { 130 return nil, err 131 } 132 } 133 134 if b.balancer != "" { 135 d.serviceConfig = cc.ParseServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.balancer)) 136 } 137 138 d.wg.Add(1) 139 go d.watcher() 140 return d, nil 141 } 142 143 // Scheme returns the naming scheme of this resolver builder. 144 func (b *srvBuilder) Scheme() string { 145 return b.scheme 146 } 147 148 type netResolver interface { 149 LookupHost(ctx context.Context, host string) (addrs []string, err error) 150 LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) 151 } 152 153 type name struct { 154 service string 155 domain string 156 } 157 158 // dnsResolver watches for the name resolution update for a non-IP target. 159 type dnsResolver struct { 160 names []name 161 resolver netResolver 162 ctx context.Context 163 cancel context.CancelFunc 164 cc resolver.ClientConn 165 // rn channel is used by ResolveNow() to force an immediate resolution of the target. 166 rn chan struct{} 167 // wg is used to enforce Close() to return after the watcher() goroutine has finished. 168 // Otherwise, data race will be possible. [Race Example] in dns_resolver_test we 169 // replace the real lookup functions with mocked ones to facilitate testing. 170 // If Close() doesn't wait for watcher() goroutine finishes, race detector sometimes 171 // will warns lookup (READ the lookup function pointers) inside watcher() goroutine 172 // has data race with replaceNetFunc (WRITE the lookup function pointers). 173 wg sync.WaitGroup 174 serviceConfig *serviceconfig.ParseResult 175 } 176 177 // ResolveNow invoke an immediate resolution of the target that this dnsResolver watches. 178 func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) { 179 select { 180 case d.rn <- struct{}{}: 181 default: 182 } 183 } 184 185 // Close closes the dnsResolver. 186 func (d *dnsResolver) Close() { 187 d.cancel() 188 d.wg.Wait() 189 } 190 191 func (d *dnsResolver) watcher() { 192 defer d.wg.Done() 193 backoffIndex := 1 194 for { 195 state, err := d.lookup() 196 if err != nil { 197 // Report error to the underlying grpc.ClientConn. 198 d.cc.ReportError(err) 199 } else { 200 if d.serviceConfig != nil { 201 state.ServiceConfig = d.serviceConfig 202 } 203 err = d.cc.UpdateState(*state) 204 } 205 206 var timer *time.Timer 207 if err == nil { 208 // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least 209 // to prevent constantly re-resolving. 210 backoffIndex = 1 211 timer = newTimerDNSResRate(minDNSResRate) 212 select { 213 case <-d.ctx.Done(): 214 timer.Stop() 215 return 216 case <-d.rn: 217 } 218 } else { 219 // Poll on an error found in DNS Resolver or an error received from ClientConn. 220 timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) 221 backoffIndex++ 222 } 223 select { 224 case <-d.ctx.Done(): 225 timer.Stop() 226 return 227 case <-timer.C: 228 } 229 } 230 } 231 232 func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { 233 var newAddrs []resolver.Address 234 var errs []error 235 for _, n := range d.names { 236 _, srvs, err := d.resolver.LookupSRV(d.ctx, n.service, "tcp", n.domain) 237 if err != nil { 238 err = handleDNSError(err, "SRV") // may become nil 239 if err != nil { 240 errs = append(errs, err) 241 continue 242 } 243 } 244 for _, s := range srvs { 245 backendAddrs, err := d.resolver.LookupHost(d.ctx, s.Target) 246 if err != nil { 247 err = handleDNSError(err, "A") // may become nil 248 if err != nil { 249 errs = append(errs, err) 250 continue 251 } 252 } 253 for _, a := range backendAddrs { 254 ip, ok := formatIP(a) 255 if !ok { 256 errs = append(errs, fmt.Errorf("srv: error parsing A record IP address %v", a)) 257 continue 258 } 259 addr := ip + ":" + strconv.Itoa(int(s.Port)) 260 newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target}) 261 } 262 } 263 } 264 // Only return an error if all lookups failed. 265 if len(errs) > 0 && len(newAddrs) == 0 { 266 return nil, errors.Join(errs...) 267 } 268 return newAddrs, nil 269 } 270 271 func handleDNSError(err error, lookupType string) error { 272 if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary { 273 // Timeouts and temporary errors should be communicated to gRPC to 274 // attempt another DNS query (with backoff). Other errors should be 275 // suppressed (they may represent the absence of a TXT record). 276 return nil 277 } 278 if err != nil { 279 err = fmt.Errorf("srv: %v record lookup error: %v", lookupType, err) 280 logger.Info(err) 281 } 282 return err 283 } 284 285 func (d *dnsResolver) lookup() (*resolver.State, error) { 286 addrs, err := d.lookupSRV() 287 if err != nil { 288 return nil, err 289 } 290 return &resolver.State{Addresses: addrs}, nil 291 } 292 293 // formatIP returns ok = false if addr is not a valid textual representation of an IP address. 294 // If addr is an IPv4 address, return the addr and ok = true. 295 // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true. 296 func formatIP(addr string) (addrIP string, ok bool) { 297 ip, err := netip.ParseAddr(addr) 298 if err != nil { 299 return "", false 300 } 301 if ip.Is4() { 302 return addr, true 303 } 304 return "[" + addr + "]", true 305 } 306 307 // parseServiceDomain takes the user input target string and parses the service domain 308 // names for SRV lookup. Input is expected to be a hostname containing at least 309 // two labels (e.g. "foo.bar", "foo.bar.baz"). The first label is the service 310 // name and the rest is the domain name. If the target is not in the expected 311 // format, an error is returned. 312 func parseServiceDomain(target string) (string, string, error) { 313 sd := strings.SplitN(target, ".", 2) 314 if len(sd) < 2 || sd[0] == "" || sd[1] == "" { 315 return "", "", fmt.Errorf("srv: hostname %q contains < 2 labels", target) 316 } 317 return sd[0], sd[1], nil 318 }