github.com/chwjbn/xclash@v0.2.0/dns/resolver.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math/rand" 8 "net" 9 "strings" 10 "time" 11 12 "github.com/chwjbn/xclash/common/cache" 13 "github.com/chwjbn/xclash/common/picker" 14 "github.com/chwjbn/xclash/component/fakeip" 15 "github.com/chwjbn/xclash/component/resolver" 16 "github.com/chwjbn/xclash/component/trie" 17 C "github.com/chwjbn/xclash/constant" 18 19 D "github.com/miekg/dns" 20 "golang.org/x/sync/singleflight" 21 ) 22 23 type dnsClient interface { 24 Exchange(m *D.Msg) (msg *D.Msg, err error) 25 ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) 26 } 27 28 type result struct { 29 Msg *D.Msg 30 Error error 31 } 32 33 type Resolver struct { 34 ipv6 bool 35 hosts *trie.DomainTrie 36 main []dnsClient 37 fallback []dnsClient 38 fallbackDomainFilters []fallbackDomainFilter 39 fallbackIPFilters []fallbackIPFilter 40 group singleflight.Group 41 lruCache *cache.LruCache 42 policy *trie.DomainTrie 43 } 44 45 // ResolveIP request with TypeA and TypeAAAA, priority return TypeA 46 func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { 47 ch := make(chan net.IP, 1) 48 go func() { 49 defer close(ch) 50 ip, err := r.resolveIP(host, D.TypeAAAA) 51 if err != nil { 52 return 53 } 54 ch <- ip 55 }() 56 57 ip, err = r.resolveIP(host, D.TypeA) 58 if err == nil { 59 return 60 } 61 62 ip, open := <-ch 63 if !open { 64 return nil, resolver.ErrIPNotFound 65 } 66 67 return ip, nil 68 } 69 70 // ResolveIPv4 request with TypeA 71 func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { 72 return r.resolveIP(host, D.TypeA) 73 } 74 75 // ResolveIPv6 request with TypeAAAA 76 func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { 77 return r.resolveIP(host, D.TypeAAAA) 78 } 79 80 func (r *Resolver) shouldIPFallback(ip net.IP) bool { 81 for _, filter := range r.fallbackIPFilters { 82 if filter.Match(ip) { 83 return true 84 } 85 } 86 return false 87 } 88 89 // Exchange a batch of dns request, and it use cache 90 func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { 91 return r.ExchangeContext(context.Background(), m) 92 } 93 94 // ExchangeContext a batch of dns request with context.Context, and it use cache 95 func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 96 if len(m.Question) == 0 { 97 return nil, errors.New("should have one question at least") 98 } 99 100 q := m.Question[0] 101 cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) 102 if hit { 103 now := time.Now() 104 msg = cache.(*D.Msg).Copy() 105 if expireTime.Before(now) { 106 setMsgTTL(msg, uint32(1)) // Continue fetch 107 go r.exchangeWithoutCache(ctx, m) 108 } else { 109 setMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) 110 } 111 return 112 } 113 return r.exchangeWithoutCache(ctx, m) 114 } 115 116 // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache 117 func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 118 q := m.Question[0] 119 120 ret, err, shared := r.group.Do(q.String(), func() (result interface{}, err error) { 121 defer func() { 122 if err != nil { 123 return 124 } 125 126 msg := result.(*D.Msg) 127 128 putMsgToCache(r.lruCache, q.String(), msg) 129 }() 130 131 isIPReq := isIPRequest(q) 132 if isIPReq { 133 return r.ipExchange(ctx, m) 134 } 135 136 if matched := r.matchPolicy(m); len(matched) != 0 { 137 return r.batchExchange(ctx, matched, m) 138 } 139 return r.batchExchange(ctx, r.main, m) 140 }) 141 142 if err == nil { 143 msg = ret.(*D.Msg) 144 if shared { 145 msg = msg.Copy() 146 } 147 } 148 149 return 150 } 151 152 func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { 153 fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout) 154 for _, client := range clients { 155 r := client 156 fast.Go(func() (interface{}, error) { 157 m, err := r.ExchangeContext(ctx, m) 158 if err != nil { 159 return nil, err 160 } else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused { 161 return nil, errors.New("server failure") 162 } 163 return m, nil 164 }) 165 } 166 167 elm := fast.Wait() 168 if elm == nil { 169 err := errors.New("all DNS requests failed") 170 if fErr := fast.Error(); fErr != nil { 171 err = fmt.Errorf("%w, first error: %s", err, fErr.Error()) 172 } 173 return nil, err 174 } 175 176 msg = elm.(*D.Msg) 177 return 178 } 179 180 func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { 181 if r.policy == nil { 182 return nil 183 } 184 185 domain := r.msgToDomain(m) 186 if domain == "" { 187 return nil 188 } 189 190 record := r.policy.Search(domain) 191 if record == nil { 192 return nil 193 } 194 195 return record.Data.([]dnsClient) 196 } 197 198 func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { 199 if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { 200 return false 201 } 202 203 domain := r.msgToDomain(m) 204 205 if domain == "" { 206 return false 207 } 208 209 for _, df := range r.fallbackDomainFilters { 210 if df.Match(domain) { 211 return true 212 } 213 } 214 215 return false 216 } 217 218 func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { 219 if matched := r.matchPolicy(m); len(matched) != 0 { 220 res := <-r.asyncExchange(ctx, matched, m) 221 return res.Msg, res.Error 222 } 223 224 onlyFallback := r.shouldOnlyQueryFallback(m) 225 226 if onlyFallback { 227 res := <-r.asyncExchange(ctx, r.fallback, m) 228 return res.Msg, res.Error 229 } 230 231 msgCh := r.asyncExchange(ctx, r.main, m) 232 233 if r.fallback == nil { // directly return if no fallback servers are available 234 res := <-msgCh 235 msg, err = res.Msg, res.Error 236 return 237 } 238 239 fallbackMsg := r.asyncExchange(ctx, r.fallback, m) 240 res := <-msgCh 241 if res.Error == nil { 242 if ips := msgToIP(res.Msg); len(ips) != 0 { 243 if !r.shouldIPFallback(ips[0]) { 244 msg = res.Msg // no need to wait for fallback result 245 err = res.Error 246 return msg, err 247 } 248 } 249 } 250 251 res = <-fallbackMsg 252 msg, err = res.Msg, res.Error 253 return 254 } 255 256 func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { 257 ip = net.ParseIP(host) 258 if ip != nil { 259 isIPv4 := ip.To4() != nil 260 if dnsType == D.TypeAAAA && !isIPv4 { 261 return ip, nil 262 } else if dnsType == D.TypeA && isIPv4 { 263 return ip, nil 264 } else { 265 return nil, resolver.ErrIPVersion 266 } 267 } 268 269 query := &D.Msg{} 270 query.SetQuestion(D.Fqdn(host), dnsType) 271 272 msg, err := r.Exchange(query) 273 if err != nil { 274 return nil, err 275 } 276 277 ips := msgToIP(msg) 278 ipLength := len(ips) 279 if ipLength == 0 { 280 return nil, resolver.ErrIPNotFound 281 } 282 283 ip = ips[rand.Intn(ipLength)] 284 return 285 } 286 287 func (r *Resolver) msgToDomain(msg *D.Msg) string { 288 if len(msg.Question) > 0 { 289 return strings.TrimRight(msg.Question[0].Name, ".") 290 } 291 292 return "" 293 } 294 295 func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result { 296 ch := make(chan *result, 1) 297 go func() { 298 res, err := r.batchExchange(ctx, client, msg) 299 ch <- &result{Msg: res, Error: err} 300 }() 301 return ch 302 } 303 304 type NameServer struct { 305 Net string 306 Addr string 307 Interface string 308 } 309 310 type FallbackFilter struct { 311 GeoIP bool 312 GeoIPCode string 313 IPCIDR []*net.IPNet 314 Domain []string 315 } 316 317 type Config struct { 318 Main, Fallback []NameServer 319 Default []NameServer 320 IPv6 bool 321 EnhancedMode C.DNSMode 322 FallbackFilter FallbackFilter 323 Pool *fakeip.Pool 324 Hosts *trie.DomainTrie 325 Policy map[string]NameServer 326 } 327 328 func NewResolver(config Config) *Resolver { 329 defaultResolver := &Resolver{ 330 main: transform(config.Default, nil), 331 lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), 332 } 333 334 r := &Resolver{ 335 ipv6: config.IPv6, 336 main: transform(config.Main, defaultResolver), 337 lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), 338 hosts: config.Hosts, 339 } 340 341 if len(config.Fallback) != 0 { 342 r.fallback = transform(config.Fallback, defaultResolver) 343 } 344 345 if len(config.Policy) != 0 { 346 r.policy = trie.New() 347 for domain, nameserver := range config.Policy { 348 r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver)) 349 } 350 } 351 352 fallbackIPFilters := []fallbackIPFilter{} 353 if config.FallbackFilter.GeoIP { 354 fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{ 355 code: config.FallbackFilter.GeoIPCode, 356 }) 357 } 358 for _, ipnet := range config.FallbackFilter.IPCIDR { 359 fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet}) 360 } 361 r.fallbackIPFilters = fallbackIPFilters 362 363 if len(config.FallbackFilter.Domain) != 0 { 364 fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)} 365 r.fallbackDomainFilters = fallbackDomainFilters 366 } 367 368 return r 369 }