github.com/yaling888/clash@v1.53.0/dns/resolver.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math/rand/v2" 8 "net/netip" 9 "strings" 10 "sync" 11 "time" 12 13 D "github.com/miekg/dns" 14 "github.com/phuslu/log" 15 "github.com/samber/lo" 16 "go.uber.org/atomic" 17 "golang.org/x/sync/singleflight" 18 19 "github.com/yaling888/clash/common/cache" 20 "github.com/yaling888/clash/component/fakeip" 21 "github.com/yaling888/clash/component/geodata/router" 22 "github.com/yaling888/clash/component/resolver" 23 "github.com/yaling888/clash/component/trie" 24 C "github.com/yaling888/clash/constant" 25 ) 26 27 type dnsClient interface { 28 Exchange(m *D.Msg) (msg *rMsg, err error) 29 ExchangeContext(ctx context.Context, m *D.Msg) (msg *rMsg, err error) 30 IsLan() bool 31 } 32 33 type result struct { 34 Msg *rMsg 35 Error error 36 Policy bool 37 } 38 39 type rMsg struct { 40 Msg *D.Msg 41 Source string 42 Lan bool 43 } 44 45 func (m *rMsg) Copy() *rMsg { 46 m1 := new(rMsg) 47 m1.Msg = m.Msg.Copy() 48 m1.Source = m.Source 49 m1.Lan = m.Lan 50 return m1 51 } 52 53 var _ resolver.Resolver = (*Resolver)(nil) 54 55 type Resolver struct { 56 ipv6 bool 57 hosts *trie.DomainTrie[netip.Addr] 58 main []dnsClient 59 fallback []dnsClient 60 proxyServer []dnsClient 61 remote []dnsClient 62 fallbackDomainFilters []fallbackDomainFilter 63 fallbackIPFilters []fallbackIPFilter 64 group singleflight.Group 65 lruCache *cache.LruCache[string, *rMsg] 66 policy *trie.DomainTrie[*Policy] 67 searchDomains []string 68 } 69 70 // LookupIP request with TypeA and TypeAAAA, priority return TypeA 71 func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []netip.Addr, err error) { 72 ctx1, cancel := context.WithCancel(ctx) 73 defer cancel() 74 75 ch := make(chan []netip.Addr, 1) 76 go func() { 77 defer close(ch) 78 ip6, err6 := r.lookupIP(ctx1, host, D.TypeAAAA) 79 if err6 != nil { 80 return 81 } 82 ch <- ip6 83 }() 84 85 ip, err = r.lookupIP(ctx1, host, D.TypeA) 86 if err == nil { 87 if resolver.IsRemote(ctx) { // force combine ipv6 list for remote resolve DNS 88 if ip6, open := <-ch; open { 89 ip = append(ip, ip6...) 90 } 91 } 92 return 93 } 94 95 ip, open := <-ch 96 if !open { 97 return nil, resolver.ErrIPNotFound 98 } 99 100 return ip, nil 101 } 102 103 // ResolveIP request with TypeA and TypeAAAA, priority return TypeA 104 func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { 105 ips, err := r.LookupIP(context.Background(), host) 106 if err != nil { 107 return netip.Addr{}, err 108 } else if len(ips) == 0 { 109 return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 110 } 111 return ips[rand.IntN(len(ips))], nil 112 } 113 114 // LookupIPv4 request with TypeA 115 func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) { 116 return r.lookupIP(ctx, host, D.TypeA) 117 } 118 119 // ResolveIPv4 request with TypeA 120 func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) { 121 ips, err := r.lookupIP(context.Background(), host, D.TypeA) 122 if err != nil { 123 return netip.Addr{}, err 124 } else if len(ips) == 0 { 125 return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 126 } 127 return ips[rand.IntN(len(ips))], nil 128 } 129 130 // LookupIPv6 request with TypeAAAA 131 func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) { 132 return r.lookupIP(ctx, host, D.TypeAAAA) 133 } 134 135 // ResolveIPv6 request with TypeAAAA 136 func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) { 137 ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA) 138 if err != nil { 139 return netip.Addr{}, err 140 } else if len(ips) == 0 { 141 return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 142 } 143 return ips[rand.IntN(len(ips))], nil 144 } 145 146 func (r *Resolver) shouldIPFallback(ip netip.Addr) bool { 147 for _, filter := range r.fallbackIPFilters { 148 if filter.Match(ip) { 149 return true 150 } 151 } 152 return false 153 } 154 155 // Exchange a batch of dns request, and it uses cache 156 func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, source string, err error) { 157 return r.ExchangeContext(context.Background(), m) 158 } 159 160 // ExchangeContext a batch of dns request with context.Context, and it uses cache 161 func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, source string, err error) { 162 if len(m.Question) == 0 { 163 return nil, "", errors.New("should have one question at least") 164 } 165 166 var ( 167 q = m.Question[0] 168 key = genMsgCacheKey(ctx, q) 169 ) 170 171 cacheM, expireTime, hit := r.lruCache.GetWithExpire(key) 172 if hit && time.Now().Before(expireTime) { 173 msg1 := cacheM.Copy() 174 msg = msg1.Msg 175 source = msg1.Source 176 setMsgMaxTTL(msg, uint32(time.Until(expireTime).Seconds())) 177 return 178 } 179 msg1, err := r.exchangeWithoutCache(ctx, m, q, key, true) 180 if err != nil { 181 return nil, "", err 182 } 183 return msg1.Msg, msg1.Source, nil 184 } 185 186 // ExchangeContextWithoutCache a batch of dns request with context.Context 187 func (r *Resolver) ExchangeContextWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, source string, err error) { 188 if len(m.Question) == 0 { 189 return nil, "", errors.New("should have one question at least") 190 } 191 192 var ( 193 q = m.Question[0] 194 key = genMsgCacheKey(ctx, q) 195 ) 196 197 msg1, err := r.exchangeWithoutCache(ctx, m, q, key, false) 198 if err != nil { 199 return nil, "", err 200 } 201 return msg1.Msg, msg1.Source, nil 202 } 203 204 // exchangeWithoutCache a batch of dns request, and it does NOT GET from cache 205 func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg, q D.Question, key string, cache bool) (msg *rMsg, err error) { 206 domain := strings.TrimRight(q.Name, ".") 207 ret, err, shared := r.group.Do(key, func() (res any, err error) { 208 defer func() { 209 if err != nil || !cache { 210 return 211 } 212 213 msg1 := res.(*rMsg) 214 215 // OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files. 216 msg1.Msg.Extra = lo.Filter(msg1.Msg.Extra, func(rr D.RR, index int) bool { 217 return rr.Header().Rrtype != D.TypeOPT 218 }) 219 220 // skip dns cache for acme challenge 221 if q.Qtype == D.TypeTXT && strings.HasPrefix(q.Name, "_acme-challenge.") { 222 log.Debug(). 223 Str("source", msg1.Source). 224 Str("qType", D.Type(q.Qtype).String()). 225 Str("name", q.Name). 226 Msg("[DNS] dns cache ignored because of acme challenge") 227 return 228 } 229 230 if resolver.IsProxyServer(ctx) { 231 // reset proxy server ip cache expire time to at least 20 minutes 232 sec := max(minTTL(msg1.Msg.Answer), 1200) 233 putMsgToCacheWithExpire(r.lruCache, key, msg1, sec) 234 return 235 } 236 237 if msg1.Msg.Rcode == D.RcodeNameError { // Non-Existent Domain 238 setTTL(msg1.Msg.Ns, 600, true) 239 } 240 241 putMsgToCache(r.lruCache, key, msg1) 242 }() 243 244 isIPReq := isIPRequest(q) 245 if isIPReq { 246 return r.ipExchange(ctx, m, domain) 247 } 248 249 var rst *result 250 if r.remote != nil && resolver.IsRemote(ctx) { 251 rst = r.exchangePolicyCombine(ctx, r.remote, m, domain) 252 } else if r.proxyServer != nil && resolver.IsProxyServer(ctx) { 253 rst = r.exchangePolicyCombine(ctx, r.proxyServer, m, domain) 254 } else { 255 rst = r.exchangePolicyCombine(ctx, r.main, m, domain) 256 } 257 return rst.Msg, rst.Error 258 }) 259 260 if err == nil { 261 msg = ret.(*rMsg) 262 if shared { 263 msg = msg.Copy() 264 } 265 } 266 267 return 268 } 269 270 func (r *Resolver) matchPolicy(domain string) ([]dnsClient, bool) { 271 if r.policy == nil || domain == "" { 272 return nil, false 273 } 274 275 record := r.policy.Search(domain) 276 if record == nil { 277 return nil, false 278 } 279 280 return record.Data.GetData(), true 281 } 282 283 func (r *Resolver) exchangePolicyCombine(ctx context.Context, clients []dnsClient, m *D.Msg, domain string) *result { 284 timeout := resolver.DefaultDNSTimeout 285 if resolver.IsRemote(ctx) { 286 timeout = proxyTimeout 287 } 288 289 res := new(result) 290 policyClients, match := r.matchPolicy(domain) 291 if !match { 292 ctx1, cancel := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout) 293 defer cancel() 294 res.Msg, res.Error = batchExchange(ctx1, clients, m) 295 return res 296 } 297 298 isLan := lo.SomeBy(policyClients, func(c dnsClient) bool { 299 return c.IsLan() 300 }) 301 302 if !isLan { 303 ctx1, cancel := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout) 304 defer cancel() 305 res.Msg, res.Error = batchExchange(ctx1, policyClients, m) 306 res.Policy = true 307 return res 308 } 309 310 var ( 311 res1, res2 *result 312 done1 = atomic.NewBool(false) 313 wg = sync.WaitGroup{} 314 ) 315 316 wg.Add(2) 317 318 ctx1, cancel1 := context.WithTimeout(resolver.CopyCtxValues(ctx), resolver.DefaultDNSTimeout) 319 defer cancel1() 320 321 ctx2, cancel2 := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout) 322 defer cancel2() 323 324 go func() { 325 msg, err := batchExchange(ctx1, policyClients, m) 326 res1 = &result{Msg: msg, Error: err, Policy: true} 327 done1.Store(true) 328 wg.Done() 329 if err == nil { 330 cancel2() // no need to wait for others 331 } 332 }() 333 334 go func() { 335 msg, err := batchExchange(ctx2, clients, m) 336 res2 = &result{Msg: msg, Error: err} 337 wg.Done() 338 if err == nil && !done1.Load() { 339 // if others done before lan policy, then wait maximum 50ms for lan policy 340 for i := 0; i < 10; i++ { 341 time.Sleep(5 * time.Millisecond) 342 if done1.Load() { // check for every 5ms 343 return 344 } 345 } 346 cancel1() 347 } 348 }() 349 350 wg.Wait() 351 352 if res1.Error == nil { 353 res = res1 354 } else { 355 res = res2 356 } 357 358 if res.Error == nil { 359 res.Msg.Lan = true 360 setMsgMaxTTL(res.Msg.Msg, 10) // reset ttl to maximum 10 seconds for lan policy 361 } 362 return res 363 } 364 365 func (r *Resolver) shouldOnlyQueryFallback(domain string) bool { 366 if r.fallback == nil || r.fallbackDomainFilters == nil || domain == "" { 367 return false 368 } 369 370 for _, df := range r.fallbackDomainFilters { 371 if df.Match(domain) { 372 return true 373 } 374 } 375 376 return false 377 } 378 379 func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg, domain string) (msg *rMsg, err error) { 380 if r.remote != nil && resolver.IsRemote(ctx) { 381 res := r.exchangePolicyCombine(ctx, r.remote, m, domain) 382 return res.Msg, res.Error 383 } 384 385 if r.proxyServer != nil && resolver.IsProxyServer(ctx) { 386 res := r.exchangePolicyCombine(ctx, r.proxyServer, m, domain) 387 return res.Msg, res.Error 388 } 389 390 if r.shouldOnlyQueryFallback(domain) { 391 res := r.exchangePolicyCombine(ctx, r.fallback, m, domain) 392 return res.Msg, res.Error 393 } 394 395 res := r.exchangePolicyCombine(ctx, r.main, m, domain) 396 msg, err = res.Msg, res.Error 397 398 if res.Policy { // directly return if from policy servers 399 return 400 } 401 402 if r.fallback == nil { // directly return if no fallback servers are available 403 return 404 } 405 406 if err == nil { 407 if ips := msgToIP(msg.Msg); len(ips) != 0 { 408 if lo.EveryBy(ips, func(ip netip.Addr) bool { 409 return !r.shouldIPFallback(ip) 410 }) { 411 // no need to wait for fallback result 412 return 413 } 414 } 415 } 416 417 res = r.exchangePolicyCombine(ctx, r.fallback, m, domain) 418 msg, err = res.Msg, res.Error 419 return 420 } 421 422 func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]netip.Addr, error) { 423 ip, err := netip.ParseAddr(host) 424 if err == nil { 425 if dnsType != D.TypeAAAA { 426 ip = ip.Unmap() 427 } 428 isIPv4 := ip.Is4() 429 if dnsType == D.TypeAAAA && !isIPv4 { 430 return []netip.Addr{ip}, nil 431 } else if dnsType == D.TypeA && isIPv4 { 432 return []netip.Addr{ip}, nil 433 } else { 434 return nil, resolver.ErrIPVersion 435 } 436 } 437 438 query := &D.Msg{} 439 query.SetQuestion(D.Fqdn(host), dnsType) 440 441 msg, _, err := r.ExchangeContext(ctx, query) 442 if err != nil { 443 return nil, err 444 } 445 446 ips := msgToIP(msg) 447 if len(ips) != 0 { 448 return ips, nil 449 } else if len(r.searchDomains) == 0 { 450 return nil, resolver.ErrIPNotFound 451 } 452 453 for _, domain := range r.searchDomains { 454 q := &D.Msg{} 455 q.SetQuestion(D.Fqdn(fmt.Sprintf("%s.%s", host, domain)), dnsType) 456 msg1, _, err1 := r.ExchangeContext(ctx, q) 457 if err1 != nil { 458 return nil, err1 459 } 460 ips1 := msgToIP(msg1) 461 if len(ips1) != 0 { 462 return ips1, nil 463 } 464 } 465 466 return nil, resolver.ErrIPNotFound 467 } 468 469 func (r *Resolver) RemoveCache(host string) { 470 q := D.Question{Name: D.Fqdn(host), Qtype: D.TypeA, Qclass: D.ClassINET} 471 r.lruCache.Delete(genMsgCacheKey(context.Background(), q)) 472 q.Qtype = D.TypeAAAA 473 r.lruCache.Delete(genMsgCacheKey(context.Background(), q)) 474 } 475 476 type NameServer struct { 477 Net string 478 Addr string 479 Interface string 480 Proxy string 481 IsDHCP bool 482 } 483 484 type FallbackFilter struct { 485 GeoIP bool 486 GeoIPCode string 487 IPCIDR []*netip.Prefix 488 Domain []string 489 GeoSite []*router.DomainMatcher 490 } 491 492 type Config struct { 493 Main, Fallback []NameServer 494 Default []NameServer 495 ProxyServer []NameServer 496 Remote []NameServer 497 IPv6 bool 498 EnhancedMode C.DNSMode 499 FallbackFilter FallbackFilter 500 Pool *fakeip.Pool 501 Hosts *trie.DomainTrie[netip.Addr] 502 Policy map[string]NameServer 503 SearchDomains []string 504 } 505 506 func NewResolver(config Config) *Resolver { 507 defaultResolver := &Resolver{ 508 main: transform(config.Default, nil), 509 lruCache: cache.New[string, *rMsg]( 510 cache.WithSize[string, *rMsg](128), 511 cache.WithStale[string, *rMsg](true), 512 ), 513 } 514 515 r := &Resolver{ 516 ipv6: config.IPv6, 517 main: transform(config.Main, defaultResolver), 518 lruCache: cache.New[string, *rMsg]( 519 cache.WithSize[string, *rMsg](10240), 520 cache.WithStale[string, *rMsg](true), 521 ), 522 hosts: config.Hosts, 523 searchDomains: config.SearchDomains, 524 } 525 526 if len(config.Fallback) != 0 { 527 r.fallback = transform(config.Fallback, defaultResolver) 528 } 529 530 if len(config.ProxyServer) != 0 { 531 r.proxyServer = transform(config.ProxyServer, defaultResolver) 532 } 533 534 if len(config.Remote) != 0 { 535 remotes := lo.Map(config.Remote, func(item NameServer, _ int) NameServer { 536 item.Proxy = "remote-resolver" 537 return item 538 }) 539 r.remote = transform(remotes, defaultResolver) 540 } 541 542 if len(config.Policy) != 0 { 543 r.policy = trie.New[*Policy]() 544 for domain, nameserver := range config.Policy { 545 _ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) 546 } 547 } 548 549 var fallbackIPFilters []fallbackIPFilter 550 if config.FallbackFilter.GeoIP { 551 fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{ 552 code: config.FallbackFilter.GeoIPCode, 553 }) 554 } 555 for _, ipnet := range config.FallbackFilter.IPCIDR { 556 fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet}) 557 } 558 r.fallbackIPFilters = fallbackIPFilters 559 560 var fallbackDomainFilters []fallbackDomainFilter 561 if len(config.FallbackFilter.Domain) != 0 { 562 fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain)) 563 } 564 565 if len(config.FallbackFilter.GeoSite) != 0 { 566 fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{ 567 matchers: config.FallbackFilter.GeoSite, 568 }) 569 } 570 r.fallbackDomainFilters = fallbackDomainFilters 571 572 return r 573 }