github.com/igoogolx/clash@v1.19.8/dns/resolver.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math/rand" 8 "net" 9 "strings" 10 "time" 11 12 "github.com/igoogolx/clash/common/cache" 13 "github.com/igoogolx/clash/component/fakeip" 14 "github.com/igoogolx/clash/component/resolver" 15 "github.com/igoogolx/clash/component/trie" 16 C "github.com/igoogolx/clash/constant" 17 18 D "github.com/miekg/dns" 19 "github.com/samber/lo" 20 "golang.org/x/sync/singleflight" 21 ) 22 23 type dnsClient interface { 24 Exchange(m *D.Msg) (msg *D.Msg, err error) 25 ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) 26 } 27 28 type result struct { 29 Msg *D.Msg 30 Error error 31 } 32 33 type Resolver struct { 34 ipv6 bool 35 hosts *trie.DomainTrie 36 main []dnsClient 37 fallback []dnsClient 38 fallbackDomainFilters []fallbackDomainFilter 39 fallbackIPFilters []fallbackIPFilter 40 group singleflight.Group 41 lruCache *cache.LruCache 42 policy *trie.DomainTrie 43 searchDomains []string 44 disableCache bool 45 } 46 47 // LookupIP request with TypeA and TypeAAAA, priority return TypeA 48 func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []net.IP, err error) { 49 ctx, cancel := context.WithCancel(ctx) 50 defer cancel() 51 52 ch := make(chan []net.IP, 1) 53 54 go func() { 55 defer close(ch) 56 ip, err := r.lookupIP(ctx, host, D.TypeAAAA) 57 if err != nil { 58 return 59 } 60 ch <- ip 61 }() 62 63 ip, err = r.lookupIP(ctx, host, D.TypeA) 64 if err == nil { 65 return 66 } 67 68 ip, open := <-ch 69 if !open { 70 return nil, resolver.ErrIPNotFound 71 } 72 73 return ip, nil 74 } 75 76 // ResolveIP request with TypeA and TypeAAAA, priority return TypeA 77 func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { 78 ips, err := r.LookupIP(context.Background(), host) 79 if err != nil { 80 return nil, err 81 } else if len(ips) == 0 { 82 return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 83 } 84 return ips[rand.Intn(len(ips))], nil 85 } 86 87 // LookupIPv4 request with TypeA 88 func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]net.IP, error) { 89 return r.lookupIP(ctx, host, D.TypeA) 90 } 91 92 // ResolveIPv4 request with TypeA 93 func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { 94 ips, err := r.lookupIP(context.Background(), host, D.TypeA) 95 if err != nil { 96 return nil, err 97 } else if len(ips) == 0 { 98 return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 99 } 100 return ips[rand.Intn(len(ips))], nil 101 } 102 103 // LookupIPv6 request with TypeAAAA 104 func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]net.IP, error) { 105 return r.lookupIP(ctx, host, D.TypeAAAA) 106 } 107 108 // ResolveIPv6 request with TypeAAAA 109 func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { 110 ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA) 111 if err != nil { 112 return nil, err 113 } else if len(ips) == 0 { 114 return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host) 115 } 116 return ips[rand.Intn(len(ips))], nil 117 } 118 119 func (r *Resolver) shouldIPFallback(ip net.IP) bool { 120 for _, filter := range r.fallbackIPFilters { 121 if filter.Match(ip) { 122 return true 123 } 124 } 125 return false 126 } 127 128 // Exchange a batch of dns request, and it use cache 129 func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { 130 return r.ExchangeContext(context.Background(), m) 131 } 132 133 // ExchangeContext a batch of dns request with context.Context, and it use cache 134 func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 135 if len(m.Question) == 0 { 136 return nil, errors.New("should have one question at least") 137 } 138 139 q := m.Question[0] 140 cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) 141 if hit { 142 now := time.Now() 143 msg = cache.(*D.Msg).Copy() 144 if expireTime.Before(now) { 145 setMsgTTL(msg, uint32(1)) // Continue fetch 146 go func() { 147 ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) 148 r.exchangeWithoutCache(ctx, m) 149 cancel() 150 }() 151 } else { 152 // updating TTL by subtracting common delta time from each DNS record 153 updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) 154 } 155 return 156 } 157 return r.exchangeWithoutCache(ctx, m) 158 } 159 160 func (r *Resolver) exchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 161 q := m.Question[0] 162 isIPReq := isIPRequest(q) 163 if isIPReq { 164 return r.ipExchange(ctx, m) 165 } 166 167 if matched := r.matchPolicy(m); len(matched) != 0 { 168 return r.batchExchange(ctx, matched, m) 169 } 170 return r.batchExchange(ctx, r.main, m) 171 } 172 173 // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache 174 func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 175 if r.disableCache { 176 msg, err = r.exchange(ctx, m) 177 } else { 178 q := m.Question[0] 179 ret, err, shared := r.group.Do(q.String(), func() (result any, err error) { 180 defer func() { 181 if err != nil { 182 return 183 } 184 185 msg := result.(*D.Msg) 186 // OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files. 187 msg.Extra = lo.Filter(msg.Extra, func(rr D.RR, index int) bool { 188 return rr.Header().Rrtype != D.TypeOPT 189 }) 190 putMsgToCache(r.lruCache, q.String(), q, msg) 191 }() 192 return r.exchange(ctx, m) 193 }) 194 if err == nil { 195 msg = ret.(*D.Msg) 196 if shared { 197 msg = msg.Copy() 198 } 199 } 200 } 201 202 return 203 } 204 205 func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { 206 ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDNSTimeout) 207 defer cancel() 208 209 return batchExchange(ctx, clients, m) 210 } 211 212 func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { 213 if r.policy == nil { 214 return nil 215 } 216 217 domain := r.msgToDomain(m) 218 if domain == "" { 219 return nil 220 } 221 222 record := r.policy.Search(domain) 223 if record == nil { 224 return nil 225 } 226 227 return record.Data.([]dnsClient) 228 } 229 230 func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { 231 if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { 232 return false 233 } 234 235 domain := r.msgToDomain(m) 236 237 if domain == "" { 238 return false 239 } 240 241 for _, df := range r.fallbackDomainFilters { 242 if df.Match(domain) { 243 return true 244 } 245 } 246 247 return false 248 } 249 250 func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 251 if matched := r.matchPolicy(m); len(matched) != 0 { 252 res := <-r.asyncExchange(ctx, matched, m) 253 return res.Msg, res.Error 254 } 255 256 onlyFallback := r.shouldOnlyQueryFallback(m) 257 258 if onlyFallback { 259 res := <-r.asyncExchange(ctx, r.fallback, m) 260 return res.Msg, res.Error 261 } 262 263 msgCh := r.asyncExchange(ctx, r.main, m) 264 265 if r.fallback == nil { // directly return if no fallback servers are available 266 res := <-msgCh 267 msg, err = res.Msg, res.Error 268 return 269 } 270 271 fallbackMsg := r.asyncExchange(ctx, r.fallback, m) 272 res := <-msgCh 273 if res.Error == nil { 274 if ips := msgToIP(res.Msg); len(ips) != 0 { 275 shouldNotFallback := lo.EveryBy(ips, func(ip net.IP) bool { 276 return !r.shouldIPFallback(ip) 277 }) 278 if shouldNotFallback { 279 msg = res.Msg // no need to wait for fallback result 280 err = res.Error 281 return msg, err 282 } 283 } 284 } 285 286 res = <-fallbackMsg 287 msg, err = res.Msg, res.Error 288 return 289 } 290 291 func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]net.IP, error) { 292 ip := net.ParseIP(host) 293 if ip != nil { 294 ip4 := ip.To4() 295 isIPv4 := ip4 != nil 296 if dnsType == D.TypeAAAA && !isIPv4 { 297 return []net.IP{ip}, nil 298 } else if dnsType == D.TypeA && isIPv4 { 299 return []net.IP{ip4}, nil 300 } else { 301 return nil, resolver.ErrIPVersion 302 } 303 } 304 305 query := &D.Msg{} 306 query.SetQuestion(D.Fqdn(host), dnsType) 307 308 msg, err := r.ExchangeContext(ctx, query) 309 if err != nil { 310 return nil, err 311 } 312 313 ips := msgToIP(msg) 314 if len(ips) != 0 { 315 return ips, nil 316 } else if len(r.searchDomains) == 0 { 317 return nil, resolver.ErrIPNotFound 318 } 319 320 // query provided search domains serially 321 for _, domain := range r.searchDomains { 322 q := &D.Msg{} 323 q.SetQuestion(D.Fqdn(fmt.Sprintf("%s.%s", host, domain)), dnsType) 324 msg, err := r.ExchangeContext(ctx, q) 325 if err != nil { 326 return nil, err 327 } 328 ips := msgToIP(msg) 329 if len(ips) != 0 { 330 return ips, nil 331 } 332 } 333 334 return nil, resolver.ErrIPNotFound 335 } 336 337 func (r *Resolver) msgToDomain(msg *D.Msg) string { 338 if len(msg.Question) > 0 { 339 return strings.TrimRight(msg.Question[0].Name, ".") 340 } 341 342 return "" 343 } 344 345 func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result { 346 ch := make(chan *result, 1) 347 go func() { 348 res, err := r.batchExchange(ctx, client, msg) 349 ch <- &result{Msg: res, Error: err} 350 }() 351 return ch 352 } 353 354 type NameServer struct { 355 Net string 356 Addr string 357 Interface string 358 } 359 360 type FallbackFilter struct { 361 GeoIP bool 362 GeoIPCode string 363 IPCIDR []*net.IPNet 364 Domain []string 365 } 366 367 type Config struct { 368 Main, Fallback []NameServer 369 Default []NameServer 370 IPv6 bool 371 EnhancedMode C.DNSMode 372 FallbackFilter FallbackFilter 373 Pool *fakeip.Pool 374 Hosts *trie.DomainTrie 375 Policy map[string]NameServer 376 SearchDomains []string 377 DisableCache bool 378 GetDialer func() (C.Proxy, error) 379 } 380 381 func NewResolver(config Config) *Resolver { 382 383 r := &Resolver{ 384 ipv6: config.IPv6, 385 main: transform(config.Main, config.GetDialer), 386 lruCache: cache.New(cache.WithSize(4096), cache.WithStale(true)), 387 hosts: config.Hosts, 388 searchDomains: config.SearchDomains, 389 disableCache: config.DisableCache, 390 } 391 392 if len(config.Fallback) != 0 { 393 r.fallback = transform(config.Fallback, config.GetDialer) 394 } 395 396 if len(config.Policy) != 0 { 397 r.policy = trie.New() 398 for domain, nameserver := range config.Policy { 399 r.policy.Insert(domain, transform([]NameServer{nameserver}, config.GetDialer)) 400 } 401 } 402 403 fallbackIPFilters := []fallbackIPFilter{} 404 if config.FallbackFilter.GeoIP { 405 fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{ 406 code: config.FallbackFilter.GeoIPCode, 407 }) 408 } 409 for _, ipnet := range config.FallbackFilter.IPCIDR { 410 fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet}) 411 } 412 r.fallbackIPFilters = fallbackIPFilters 413 414 if len(config.FallbackFilter.Domain) != 0 { 415 fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)} 416 r.fallbackDomainFilters = fallbackDomainFilters 417 } 418 419 return r 420 }