github.com/kelleygo/clashcore@v1.0.2/dns/resolver.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "net/netip" 7 "strings" 8 "time" 9 10 "github.com/kelleygo/clashcore/common/arc" 11 "github.com/kelleygo/clashcore/common/lru" 12 "github.com/kelleygo/clashcore/component/fakeip" 13 "github.com/kelleygo/clashcore/component/geodata/router" 14 "github.com/kelleygo/clashcore/component/resolver" 15 "github.com/kelleygo/clashcore/component/trie" 16 C "github.com/kelleygo/clashcore/constant" 17 "github.com/kelleygo/clashcore/constant/provider" 18 "github.com/kelleygo/clashcore/log" 19 20 D "github.com/miekg/dns" 21 "github.com/samber/lo" 22 orderedmap "github.com/wk8/go-ordered-map/v2" 23 "golang.org/x/exp/maps" 24 "golang.org/x/sync/singleflight" 25 ) 26 27 type dnsClient interface { 28 ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) 29 Address() string 30 } 31 32 type dnsCache interface { 33 GetWithExpire(key string) (*D.Msg, time.Time, bool) 34 SetWithExpire(key string, value *D.Msg, expire time.Time) 35 } 36 37 type result struct { 38 Msg *D.Msg 39 Error error 40 } 41 42 type Resolver struct { 43 ipv6 bool 44 ipv6Timeout time.Duration 45 hosts *trie.DomainTrie[resolver.HostValue] 46 main []dnsClient 47 fallback []dnsClient 48 fallbackDomainFilters []fallbackDomainFilter 49 fallbackIPFilters []fallbackIPFilter 50 group singleflight.Group 51 cache dnsCache 52 policy []dnsPolicy 53 proxyServer []dnsClient 54 } 55 56 func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) { 57 ch := make(chan []netip.Addr, 1) 58 go func() { 59 defer close(ch) 60 ip, err := r.lookupIP(ctx, host, D.TypeAAAA) 61 if err != nil { 62 return 63 } 64 ch <- ip 65 }() 66 67 ips, err = r.lookupIP(ctx, host, D.TypeA) 68 if err == nil { 69 return 70 } 71 72 ip, open := <-ch 73 if !open { 74 return nil, resolver.ErrIPNotFound 75 } 76 77 return ip, nil 78 } 79 80 func (r *Resolver) LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) { 81 ch := make(chan []netip.Addr, 1) 82 go func() { 83 defer close(ch) 84 ip, err := r.lookupIP(ctx, host, D.TypeAAAA) 85 if err != nil { 86 return 87 } 88 89 ch <- ip 90 }() 91 92 ips, err = r.lookupIP(ctx, host, D.TypeA) 93 var waitIPv6 *time.Timer 94 if r != nil && r.ipv6Timeout > 0 { 95 waitIPv6 = time.NewTimer(r.ipv6Timeout) 96 } else { 97 waitIPv6 = time.NewTimer(100 * time.Millisecond) 98 } 99 defer waitIPv6.Stop() 100 select { 101 case ipv6s, open := <-ch: 102 if !open && err != nil { 103 return nil, resolver.ErrIPNotFound 104 } 105 ips = append(ips, ipv6s...) 106 case <-waitIPv6.C: 107 // wait ipv6 result 108 } 109 110 return ips, nil 111 } 112 113 // LookupIPv4 request with TypeA 114 func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) { 115 return r.lookupIP(ctx, host, D.TypeA) 116 } 117 118 // LookupIPv6 request with TypeAAAA 119 func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) { 120 return r.lookupIP(ctx, host, D.TypeAAAA) 121 } 122 123 func (r *Resolver) shouldIPFallback(ip netip.Addr) bool { 124 for _, filter := range r.fallbackIPFilters { 125 if filter.Match(ip) { 126 return true 127 } 128 } 129 return false 130 } 131 132 // ExchangeContext a batch of dns request with context.Context, and it use cache 133 func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 134 if len(m.Question) == 0 { 135 return nil, errors.New("should have one question at least") 136 } 137 continueFetch := false 138 defer func() { 139 if continueFetch || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { 140 go func() { 141 ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) 142 defer cancel() 143 _, _ = r.exchangeWithoutCache(ctx, m) // ignore result, just for putMsgToCache 144 }() 145 } 146 }() 147 148 q := m.Question[0] 149 cacheM, expireTime, hit := r.cache.GetWithExpire(q.String()) 150 if hit { 151 log.Debugln("[DNS] cache hit for %s, expire at %s", q.Name, expireTime.Format("2006-01-02 15:04:05")) 152 now := time.Now() 153 msg = cacheM.Copy() 154 if expireTime.Before(now) { 155 setMsgTTL(msg, uint32(1)) // Continue fetch 156 continueFetch = true 157 } else { 158 // updating TTL by subtracting common delta time from each DNS record 159 updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) 160 } 161 return 162 } 163 return r.exchangeWithoutCache(ctx, m) 164 } 165 166 // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache 167 func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 168 q := m.Question[0] 169 170 retryNum := 0 171 retryMax := 3 172 fn := func() (result any, err error) { 173 ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) // reset timeout in singleflight 174 defer cancel() 175 cache := false 176 177 defer func() { 178 if err != nil { 179 result = retryNum 180 retryNum++ 181 return 182 } 183 184 msg := result.(*D.Msg) 185 186 if cache { 187 // OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files. 188 msg.Extra = lo.Filter(msg.Extra, func(rr D.RR, index int) bool { 189 return rr.Header().Rrtype != D.TypeOPT 190 }) 191 putMsgToCache(r.cache, q.String(), q, msg) 192 } 193 }() 194 195 isIPReq := isIPRequest(q) 196 if isIPReq { 197 cache = true 198 return r.ipExchange(ctx, m) 199 } 200 201 if matched := r.matchPolicy(m); len(matched) != 0 { 202 result, cache, err = batchExchange(ctx, matched, m) 203 return 204 } 205 result, cache, err = batchExchange(ctx, r.main, m) 206 return 207 } 208 209 ch := r.group.DoChan(q.String(), fn) 210 211 var result singleflight.Result 212 213 select { 214 case result = <-ch: 215 break 216 case <-ctx.Done(): 217 select { 218 case result = <-ch: // maybe ctxDone and chFinish in same time, get DoChan's result as much as possible 219 break 220 default: 221 go func() { // start a retrying monitor in background 222 result := <-ch 223 ret, err, shared := result.Val, result.Err, result.Shared 224 if err != nil && !shared && ret.(int) < retryMax { // retry 225 r.group.DoChan(q.String(), fn) 226 } 227 }() 228 return nil, ctx.Err() 229 } 230 } 231 232 ret, err, shared := result.Val, result.Err, result.Shared 233 if err != nil && !shared && ret.(int) < retryMax { // retry 234 r.group.DoChan(q.String(), fn) 235 } 236 237 if err == nil { 238 msg = ret.(*D.Msg) 239 if shared { 240 msg = msg.Copy() 241 } 242 } 243 244 return 245 } 246 247 func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { 248 if r.policy == nil { 249 return nil 250 } 251 252 domain := msgToDomain(m) 253 if domain == "" { 254 return nil 255 } 256 257 for _, policy := range r.policy { 258 if dnsClients := policy.Match(domain); len(dnsClients) > 0 { 259 return dnsClients 260 } 261 } 262 return nil 263 } 264 265 func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { 266 if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { 267 return false 268 } 269 270 domain := msgToDomain(m) 271 272 if domain == "" { 273 return false 274 } 275 276 for _, df := range r.fallbackDomainFilters { 277 if df.Match(domain) { 278 return true 279 } 280 } 281 282 return false 283 } 284 285 func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 286 if matched := r.matchPolicy(m); len(matched) != 0 { 287 res := <-r.asyncExchange(ctx, matched, m) 288 return res.Msg, res.Error 289 } 290 291 onlyFallback := r.shouldOnlyQueryFallback(m) 292 293 if onlyFallback { 294 res := <-r.asyncExchange(ctx, r.fallback, m) 295 return res.Msg, res.Error 296 } 297 298 msgCh := r.asyncExchange(ctx, r.main, m) 299 300 if r.fallback == nil || len(r.fallback) == 0 { // directly return if no fallback servers are available 301 res := <-msgCh 302 msg, err = res.Msg, res.Error 303 return 304 } 305 306 res := <-msgCh 307 if res.Error == nil { 308 if ips := msgToIP(res.Msg); len(ips) != 0 { 309 shouldNotFallback := lo.EveryBy(ips, func(ip netip.Addr) bool { 310 return !r.shouldIPFallback(ip) 311 }) 312 if shouldNotFallback { 313 msg, err = res.Msg, res.Error // no need to wait for fallback result 314 return 315 } 316 } 317 } 318 319 res = <-r.asyncExchange(ctx, r.fallback, m) 320 msg, err = res.Msg, res.Error 321 return 322 } 323 324 func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) (ips []netip.Addr, err error) { 325 ip, err := netip.ParseAddr(host) 326 if err == nil { 327 isIPv4 := ip.Is4() || ip.Is4In6() 328 if dnsType == D.TypeAAAA && !isIPv4 { 329 return []netip.Addr{ip}, nil 330 } else if dnsType == D.TypeA && isIPv4 { 331 return []netip.Addr{ip}, nil 332 } else { 333 return []netip.Addr{}, resolver.ErrIPVersion 334 } 335 } 336 337 query := &D.Msg{} 338 query.SetQuestion(D.Fqdn(host), dnsType) 339 340 msg, err := r.ExchangeContext(ctx, query) 341 if err != nil { 342 return []netip.Addr{}, err 343 } 344 345 ips = msgToIP(msg) 346 ipLength := len(ips) 347 if ipLength == 0 { 348 return []netip.Addr{}, resolver.ErrIPNotFound 349 } 350 351 return 352 } 353 354 func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result { 355 ch := make(chan *result, 1) 356 go func() { 357 res, _, err := batchExchange(ctx, client, msg) 358 ch <- &result{Msg: res, Error: err} 359 }() 360 return ch 361 } 362 363 // Invalid return this resolver can or can't be used 364 func (r *Resolver) Invalid() bool { 365 if r == nil { 366 return false 367 } 368 return len(r.main) > 0 369 } 370 371 type NameServer struct { 372 Net string 373 Addr string 374 Interface string 375 ProxyAdapter C.ProxyAdapter 376 ProxyName string 377 Params map[string]string 378 PreferH3 bool 379 } 380 381 func (ns NameServer) Equal(ns2 NameServer) bool { 382 defer func() { 383 // C.ProxyAdapter compare maybe panic, just ignore 384 recover() 385 }() 386 if ns.Net == ns2.Net && 387 ns.Addr == ns2.Addr && 388 ns.Interface == ns2.Interface && 389 ns.ProxyAdapter == ns2.ProxyAdapter && 390 ns.ProxyName == ns2.ProxyName && 391 maps.Equal(ns.Params, ns2.Params) && 392 ns.PreferH3 == ns2.PreferH3 { 393 return true 394 } 395 return false 396 } 397 398 type FallbackFilter struct { 399 GeoIP bool 400 GeoIPCode string 401 IPCIDR []netip.Prefix 402 Domain []string 403 GeoSite []router.DomainMatcher 404 } 405 406 type Config struct { 407 Main, Fallback []NameServer 408 Default []NameServer 409 ProxyServer []NameServer 410 IPv6 bool 411 IPv6Timeout uint 412 EnhancedMode C.DNSMode 413 FallbackFilter FallbackFilter 414 Pool *fakeip.Pool 415 Hosts *trie.DomainTrie[resolver.HostValue] 416 Policy *orderedmap.OrderedMap[string, []NameServer] 417 RuleProviders map[string]provider.RuleProvider 418 CacheAlgorithm string 419 } 420 421 func NewResolver(config Config) *Resolver { 422 var cache dnsCache 423 if config.CacheAlgorithm == "lru" { 424 cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true)) 425 } else { 426 cache = arc.New(arc.WithSize[string, *D.Msg](4096)) 427 } 428 defaultResolver := &Resolver{ 429 main: transform(config.Default, nil), 430 cache: cache, 431 ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, 432 } 433 434 var nameServerCache []struct { 435 NameServer 436 dnsClient 437 } 438 cacheTransform := func(nameserver []NameServer) (result []dnsClient) { 439 LOOP: 440 for _, ns := range nameserver { 441 for _, nsc := range nameServerCache { 442 if nsc.NameServer.Equal(ns) { 443 result = append(result, nsc.dnsClient) 444 continue LOOP 445 } 446 } 447 // not in cache 448 dc := transform([]NameServer{ns}, defaultResolver) 449 if len(dc) > 0 { 450 dc := dc[0] 451 nameServerCache = append(nameServerCache, struct { 452 NameServer 453 dnsClient 454 }{NameServer: ns, dnsClient: dc}) 455 result = append(result, dc) 456 } 457 } 458 return 459 } 460 461 if config.CacheAlgorithm == "" || config.CacheAlgorithm == "lru" { 462 cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true)) 463 } else { 464 cache = arc.New(arc.WithSize[string, *D.Msg](4096)) 465 } 466 r := &Resolver{ 467 ipv6: config.IPv6, 468 main: cacheTransform(config.Main), 469 cache: cache, 470 hosts: config.Hosts, 471 ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, 472 } 473 474 if len(config.Fallback) != 0 { 475 r.fallback = cacheTransform(config.Fallback) 476 } 477 478 if len(config.ProxyServer) != 0 { 479 r.proxyServer = cacheTransform(config.ProxyServer) 480 } 481 482 if config.Policy.Len() != 0 { 483 r.policy = make([]dnsPolicy, 0) 484 485 var triePolicy *trie.DomainTrie[[]dnsClient] 486 insertPolicy := func(policy dnsPolicy) { 487 if triePolicy != nil { 488 triePolicy.Optimize() 489 r.policy = append(r.policy, domainTriePolicy{triePolicy}) 490 triePolicy = nil 491 } 492 if policy != nil { 493 r.policy = append(r.policy, policy) 494 } 495 } 496 497 for pair := config.Policy.Oldest(); pair != nil; pair = pair.Next() { 498 domain, nameserver := pair.Key, pair.Value 499 500 if temp := strings.Split(domain, ":"); len(temp) == 2 { 501 prefix := temp[0] 502 key := temp[1] 503 switch prefix { 504 case "rule-set": 505 if p, ok := config.RuleProviders[key]; ok { 506 log.Debugln("Adding rule-set policy: %s ", key) 507 insertPolicy(domainSetPolicy{ 508 domainSetProvider: p, 509 dnsClients: cacheTransform(nameserver), 510 }) 511 continue 512 } else { 513 log.Warnln("Can't found ruleset policy: %s", key) 514 } 515 case "geosite": 516 inverse := false 517 if strings.HasPrefix(key, "!") { 518 inverse = true 519 key = key[1:] 520 } 521 log.Debugln("Adding geosite policy: %s inversed %t", key, inverse) 522 matcher, err := NewGeoSite(key) 523 if err != nil { 524 log.Warnln("adding geosite policy %s error: %s", key, err) 525 continue 526 } 527 insertPolicy(geositePolicy{ 528 matcher: matcher, 529 inverse: inverse, 530 dnsClients: cacheTransform(nameserver), 531 }) 532 continue // skip triePolicy new 533 } 534 } 535 if triePolicy == nil { 536 triePolicy = trie.New[[]dnsClient]() 537 } 538 _ = triePolicy.Insert(domain, cacheTransform(nameserver)) 539 } 540 insertPolicy(nil) 541 } 542 543 fallbackIPFilters := []fallbackIPFilter{} 544 if config.FallbackFilter.GeoIP { 545 fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{ 546 code: config.FallbackFilter.GeoIPCode, 547 }) 548 } 549 for _, ipnet := range config.FallbackFilter.IPCIDR { 550 fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet}) 551 } 552 r.fallbackIPFilters = fallbackIPFilters 553 554 fallbackDomainFilters := []fallbackDomainFilter{} 555 if len(config.FallbackFilter.Domain) != 0 { 556 fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain)) 557 } 558 559 if len(config.FallbackFilter.GeoSite) != 0 { 560 fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{ 561 matchers: config.FallbackFilter.GeoSite, 562 }) 563 } 564 r.fallbackDomainFilters = fallbackDomainFilters 565 566 return r 567 } 568 569 func NewProxyServerHostResolver(old *Resolver) *Resolver { 570 r := &Resolver{ 571 ipv6: old.ipv6, 572 main: old.proxyServer, 573 cache: old.cache, 574 hosts: old.hosts, 575 ipv6Timeout: old.ipv6Timeout, 576 } 577 return r 578 } 579 580 var ParseNameServer func(servers []string) ([]NameServer, error) // define in config/config.go