github.com/yaling888/clash@v1.53.0/dns/resolver.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand/v2"
     8  	"net/netip"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	D "github.com/miekg/dns"
    14  	"github.com/phuslu/log"
    15  	"github.com/samber/lo"
    16  	"go.uber.org/atomic"
    17  	"golang.org/x/sync/singleflight"
    18  
    19  	"github.com/yaling888/clash/common/cache"
    20  	"github.com/yaling888/clash/component/fakeip"
    21  	"github.com/yaling888/clash/component/geodata/router"
    22  	"github.com/yaling888/clash/component/resolver"
    23  	"github.com/yaling888/clash/component/trie"
    24  	C "github.com/yaling888/clash/constant"
    25  )
    26  
    27  type dnsClient interface {
    28  	Exchange(m *D.Msg) (msg *rMsg, err error)
    29  	ExchangeContext(ctx context.Context, m *D.Msg) (msg *rMsg, err error)
    30  	IsLan() bool
    31  }
    32  
    33  type result struct {
    34  	Msg    *rMsg
    35  	Error  error
    36  	Policy bool
    37  }
    38  
    39  type rMsg struct {
    40  	Msg    *D.Msg
    41  	Source string
    42  	Lan    bool
    43  }
    44  
    45  func (m *rMsg) Copy() *rMsg {
    46  	m1 := new(rMsg)
    47  	m1.Msg = m.Msg.Copy()
    48  	m1.Source = m.Source
    49  	m1.Lan = m.Lan
    50  	return m1
    51  }
    52  
    53  var _ resolver.Resolver = (*Resolver)(nil)
    54  
    55  type Resolver struct {
    56  	ipv6                  bool
    57  	hosts                 *trie.DomainTrie[netip.Addr]
    58  	main                  []dnsClient
    59  	fallback              []dnsClient
    60  	proxyServer           []dnsClient
    61  	remote                []dnsClient
    62  	fallbackDomainFilters []fallbackDomainFilter
    63  	fallbackIPFilters     []fallbackIPFilter
    64  	group                 singleflight.Group
    65  	lruCache              *cache.LruCache[string, *rMsg]
    66  	policy                *trie.DomainTrie[*Policy]
    67  	searchDomains         []string
    68  }
    69  
    70  // LookupIP request with TypeA and TypeAAAA, priority return TypeA
    71  func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []netip.Addr, err error) {
    72  	ctx1, cancel := context.WithCancel(ctx)
    73  	defer cancel()
    74  
    75  	ch := make(chan []netip.Addr, 1)
    76  	go func() {
    77  		defer close(ch)
    78  		ip6, err6 := r.lookupIP(ctx1, host, D.TypeAAAA)
    79  		if err6 != nil {
    80  			return
    81  		}
    82  		ch <- ip6
    83  	}()
    84  
    85  	ip, err = r.lookupIP(ctx1, host, D.TypeA)
    86  	if err == nil {
    87  		if resolver.IsRemote(ctx) { // force combine ipv6 list for remote resolve DNS
    88  			if ip6, open := <-ch; open {
    89  				ip = append(ip, ip6...)
    90  			}
    91  		}
    92  		return
    93  	}
    94  
    95  	ip, open := <-ch
    96  	if !open {
    97  		return nil, resolver.ErrIPNotFound
    98  	}
    99  
   100  	return ip, nil
   101  }
   102  
   103  // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
   104  func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
   105  	ips, err := r.LookupIP(context.Background(), host)
   106  	if err != nil {
   107  		return netip.Addr{}, err
   108  	} else if len(ips) == 0 {
   109  		return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
   110  	}
   111  	return ips[rand.IntN(len(ips))], nil
   112  }
   113  
   114  // LookupIPv4 request with TypeA
   115  func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) {
   116  	return r.lookupIP(ctx, host, D.TypeA)
   117  }
   118  
   119  // ResolveIPv4 request with TypeA
   120  func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) {
   121  	ips, err := r.lookupIP(context.Background(), host, D.TypeA)
   122  	if err != nil {
   123  		return netip.Addr{}, err
   124  	} else if len(ips) == 0 {
   125  		return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
   126  	}
   127  	return ips[rand.IntN(len(ips))], nil
   128  }
   129  
   130  // LookupIPv6 request with TypeAAAA
   131  func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) {
   132  	return r.lookupIP(ctx, host, D.TypeAAAA)
   133  }
   134  
   135  // ResolveIPv6 request with TypeAAAA
   136  func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) {
   137  	ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA)
   138  	if err != nil {
   139  		return netip.Addr{}, err
   140  	} else if len(ips) == 0 {
   141  		return netip.Addr{}, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
   142  	}
   143  	return ips[rand.IntN(len(ips))], nil
   144  }
   145  
   146  func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
   147  	for _, filter := range r.fallbackIPFilters {
   148  		if filter.Match(ip) {
   149  			return true
   150  		}
   151  	}
   152  	return false
   153  }
   154  
   155  // Exchange a batch of dns request, and it uses cache
   156  func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, source string, err error) {
   157  	return r.ExchangeContext(context.Background(), m)
   158  }
   159  
   160  // ExchangeContext a batch of dns request with context.Context, and it uses cache
   161  func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, source string, err error) {
   162  	if len(m.Question) == 0 {
   163  		return nil, "", errors.New("should have one question at least")
   164  	}
   165  
   166  	var (
   167  		q   = m.Question[0]
   168  		key = genMsgCacheKey(ctx, q)
   169  	)
   170  
   171  	cacheM, expireTime, hit := r.lruCache.GetWithExpire(key)
   172  	if hit && time.Now().Before(expireTime) {
   173  		msg1 := cacheM.Copy()
   174  		msg = msg1.Msg
   175  		source = msg1.Source
   176  		setMsgMaxTTL(msg, uint32(time.Until(expireTime).Seconds()))
   177  		return
   178  	}
   179  	msg1, err := r.exchangeWithoutCache(ctx, m, q, key, true)
   180  	if err != nil {
   181  		return nil, "", err
   182  	}
   183  	return msg1.Msg, msg1.Source, nil
   184  }
   185  
   186  // ExchangeContextWithoutCache a batch of dns request with context.Context
   187  func (r *Resolver) ExchangeContextWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, source string, err error) {
   188  	if len(m.Question) == 0 {
   189  		return nil, "", errors.New("should have one question at least")
   190  	}
   191  
   192  	var (
   193  		q   = m.Question[0]
   194  		key = genMsgCacheKey(ctx, q)
   195  	)
   196  
   197  	msg1, err := r.exchangeWithoutCache(ctx, m, q, key, false)
   198  	if err != nil {
   199  		return nil, "", err
   200  	}
   201  	return msg1.Msg, msg1.Source, nil
   202  }
   203  
   204  // exchangeWithoutCache a batch of dns request, and it does NOT GET from cache
   205  func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg, q D.Question, key string, cache bool) (msg *rMsg, err error) {
   206  	domain := strings.TrimRight(q.Name, ".")
   207  	ret, err, shared := r.group.Do(key, func() (res any, err error) {
   208  		defer func() {
   209  			if err != nil || !cache {
   210  				return
   211  			}
   212  
   213  			msg1 := res.(*rMsg)
   214  
   215  			// OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files.
   216  			msg1.Msg.Extra = lo.Filter(msg1.Msg.Extra, func(rr D.RR, index int) bool {
   217  				return rr.Header().Rrtype != D.TypeOPT
   218  			})
   219  
   220  			// skip dns cache for acme challenge
   221  			if q.Qtype == D.TypeTXT && strings.HasPrefix(q.Name, "_acme-challenge.") {
   222  				log.Debug().
   223  					Str("source", msg1.Source).
   224  					Str("qType", D.Type(q.Qtype).String()).
   225  					Str("name", q.Name).
   226  					Msg("[DNS] dns cache ignored because of acme challenge")
   227  				return
   228  			}
   229  
   230  			if resolver.IsProxyServer(ctx) {
   231  				// reset proxy server ip cache expire time to at least 20 minutes
   232  				sec := max(minTTL(msg1.Msg.Answer), 1200)
   233  				putMsgToCacheWithExpire(r.lruCache, key, msg1, sec)
   234  				return
   235  			}
   236  
   237  			if msg1.Msg.Rcode == D.RcodeNameError { // Non-Existent Domain
   238  				setTTL(msg1.Msg.Ns, 600, true)
   239  			}
   240  
   241  			putMsgToCache(r.lruCache, key, msg1)
   242  		}()
   243  
   244  		isIPReq := isIPRequest(q)
   245  		if isIPReq {
   246  			return r.ipExchange(ctx, m, domain)
   247  		}
   248  
   249  		var rst *result
   250  		if r.remote != nil && resolver.IsRemote(ctx) {
   251  			rst = r.exchangePolicyCombine(ctx, r.remote, m, domain)
   252  		} else if r.proxyServer != nil && resolver.IsProxyServer(ctx) {
   253  			rst = r.exchangePolicyCombine(ctx, r.proxyServer, m, domain)
   254  		} else {
   255  			rst = r.exchangePolicyCombine(ctx, r.main, m, domain)
   256  		}
   257  		return rst.Msg, rst.Error
   258  	})
   259  
   260  	if err == nil {
   261  		msg = ret.(*rMsg)
   262  		if shared {
   263  			msg = msg.Copy()
   264  		}
   265  	}
   266  
   267  	return
   268  }
   269  
   270  func (r *Resolver) matchPolicy(domain string) ([]dnsClient, bool) {
   271  	if r.policy == nil || domain == "" {
   272  		return nil, false
   273  	}
   274  
   275  	record := r.policy.Search(domain)
   276  	if record == nil {
   277  		return nil, false
   278  	}
   279  
   280  	return record.Data.GetData(), true
   281  }
   282  
// exchangePolicyCombine sends m to clients, or to the clients of the
// nameserver policy matching domain when there is one.
//
// Exchanges are bounded by resolver.DefaultDNSTimeout, or by proxyTimeout
// when ctx is marked remote. Each exchange runs on a context derived from
// resolver.CopyCtxValues(ctx) rather than ctx itself. NOTE(review):
// presumably this detaches from ctx cancellation — confirm CopyCtxValues.
//
// When the matched policy contains a LAN client (IsLan), the policy clients
// and the regular clients are raced:
//   - the policy query is always bounded by resolver.DefaultDNSTimeout;
//   - a successful policy answer cancels the regular query;
//   - if the regular query succeeds first, the policy query gets up to ~50ms
//     more (polled every 5ms) before it is cancelled;
//   - the function waits for both queries and prefers the policy outcome
//     whenever it succeeded.
//
// A successful answer from this LAN path is marked Lan and has its TTLs
// limited to 10 via setMsgMaxTTL, even when it came from the regular clients.
//
// The returned *result is never nil. Its Policy field is true only when the
// outcome came from the policy clients.
func (r *Resolver) exchangePolicyCombine(ctx context.Context, clients []dnsClient, m *D.Msg, domain string) *result {
	timeout := resolver.DefaultDNSTimeout
	if resolver.IsRemote(ctx) {
		timeout = proxyTimeout
	}

	res := new(result)
	policyClients, match := r.matchPolicy(domain)
	if !match {
		// No policy: plain exchange against the supplied group.
		ctx1, cancel := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout)
		defer cancel()
		res.Msg, res.Error = batchExchange(ctx1, clients, m)
		return res
	}

	isLan := lo.SomeBy(policyClients, func(c dnsClient) bool {
		return c.IsLan()
	})

	if !isLan {
		// Non-LAN policy: the policy clients fully replace the supplied group.
		ctx1, cancel := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout)
		defer cancel()
		res.Msg, res.Error = batchExchange(ctx1, policyClients, m)
		res.Policy = true
		return res
	}

	// LAN policy: race the policy clients (res1) against the regular ones (res2).
	var (
		res1, res2 *result
		done1      = atomic.NewBool(false)
		wg         = sync.WaitGroup{}
	)

	wg.Add(2)

	ctx1, cancel1 := context.WithTimeout(resolver.CopyCtxValues(ctx), resolver.DefaultDNSTimeout)
	defer cancel1()

	ctx2, cancel2 := context.WithTimeout(resolver.CopyCtxValues(ctx), timeout)
	defer cancel2()

	go func() {
		msg, err := batchExchange(ctx1, policyClients, m)
		res1 = &result{Msg: msg, Error: err, Policy: true}
		// done1 is published before wg.Done so the grace-period loop below
		// observes completion no later than wg.Wait returns.
		done1.Store(true)
		wg.Done()
		if err == nil {
			cancel2() // no need to wait for others
		}
	}()

	go func() {
		msg, err := batchExchange(ctx2, clients, m)
		res2 = &result{Msg: msg, Error: err}
		// Done is signalled before the grace period; the main goroutine still
		// blocks on the policy goroutine, which this loop may cancel.
		wg.Done()
		if err == nil && !done1.Load() {
			// if others done before lan policy, then wait maximum 50ms for lan policy
			for i := 0; i < 10; i++ {
				time.Sleep(5 * time.Millisecond)
				if done1.Load() { // check for every 5ms
					return
				}
			}
			cancel1()
		}
	}()

	wg.Wait()

	// res1/res2 are safe to read here: both writes happen before wg.Done.
	if res1.Error == nil {
		res = res1
	} else {
		res = res2
	}

	if res.Error == nil {
		res.Msg.Lan = true
		setMsgMaxTTL(res.Msg.Msg, 10) // reset ttl to maximum 10 seconds for lan policy
	}
	return res
}
   364  
   365  func (r *Resolver) shouldOnlyQueryFallback(domain string) bool {
   366  	if r.fallback == nil || r.fallbackDomainFilters == nil || domain == "" {
   367  		return false
   368  	}
   369  
   370  	for _, df := range r.fallbackDomainFilters {
   371  		if df.Match(domain) {
   372  			return true
   373  		}
   374  	}
   375  
   376  	return false
   377  }
   378  
   379  func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg, domain string) (msg *rMsg, err error) {
   380  	if r.remote != nil && resolver.IsRemote(ctx) {
   381  		res := r.exchangePolicyCombine(ctx, r.remote, m, domain)
   382  		return res.Msg, res.Error
   383  	}
   384  
   385  	if r.proxyServer != nil && resolver.IsProxyServer(ctx) {
   386  		res := r.exchangePolicyCombine(ctx, r.proxyServer, m, domain)
   387  		return res.Msg, res.Error
   388  	}
   389  
   390  	if r.shouldOnlyQueryFallback(domain) {
   391  		res := r.exchangePolicyCombine(ctx, r.fallback, m, domain)
   392  		return res.Msg, res.Error
   393  	}
   394  
   395  	res := r.exchangePolicyCombine(ctx, r.main, m, domain)
   396  	msg, err = res.Msg, res.Error
   397  
   398  	if res.Policy { // directly return if from policy servers
   399  		return
   400  	}
   401  
   402  	if r.fallback == nil { // directly return if no fallback servers are available
   403  		return
   404  	}
   405  
   406  	if err == nil {
   407  		if ips := msgToIP(msg.Msg); len(ips) != 0 {
   408  			if lo.EveryBy(ips, func(ip netip.Addr) bool {
   409  				return !r.shouldIPFallback(ip)
   410  			}) {
   411  				// no need to wait for fallback result
   412  				return
   413  			}
   414  		}
   415  	}
   416  
   417  	res = r.exchangePolicyCombine(ctx, r.fallback, m, domain)
   418  	msg, err = res.Msg, res.Error
   419  	return
   420  }
   421  
   422  func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]netip.Addr, error) {
   423  	ip, err := netip.ParseAddr(host)
   424  	if err == nil {
   425  		if dnsType != D.TypeAAAA {
   426  			ip = ip.Unmap()
   427  		}
   428  		isIPv4 := ip.Is4()
   429  		if dnsType == D.TypeAAAA && !isIPv4 {
   430  			return []netip.Addr{ip}, nil
   431  		} else if dnsType == D.TypeA && isIPv4 {
   432  			return []netip.Addr{ip}, nil
   433  		} else {
   434  			return nil, resolver.ErrIPVersion
   435  		}
   436  	}
   437  
   438  	query := &D.Msg{}
   439  	query.SetQuestion(D.Fqdn(host), dnsType)
   440  
   441  	msg, _, err := r.ExchangeContext(ctx, query)
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  
   446  	ips := msgToIP(msg)
   447  	if len(ips) != 0 {
   448  		return ips, nil
   449  	} else if len(r.searchDomains) == 0 {
   450  		return nil, resolver.ErrIPNotFound
   451  	}
   452  
   453  	for _, domain := range r.searchDomains {
   454  		q := &D.Msg{}
   455  		q.SetQuestion(D.Fqdn(fmt.Sprintf("%s.%s", host, domain)), dnsType)
   456  		msg1, _, err1 := r.ExchangeContext(ctx, q)
   457  		if err1 != nil {
   458  			return nil, err1
   459  		}
   460  		ips1 := msgToIP(msg1)
   461  		if len(ips1) != 0 {
   462  			return ips1, nil
   463  		}
   464  	}
   465  
   466  	return nil, resolver.ErrIPNotFound
   467  }
   468  
   469  func (r *Resolver) RemoveCache(host string) {
   470  	q := D.Question{Name: D.Fqdn(host), Qtype: D.TypeA, Qclass: D.ClassINET}
   471  	r.lruCache.Delete(genMsgCacheKey(context.Background(), q))
   472  	q.Qtype = D.TypeAAAA
   473  	r.lruCache.Delete(genMsgCacheKey(context.Background(), q))
   474  }
   475  
   476  type NameServer struct {
   477  	Net       string
   478  	Addr      string
   479  	Interface string
   480  	Proxy     string
   481  	IsDHCP    bool
   482  }
   483  
   484  type FallbackFilter struct {
   485  	GeoIP     bool
   486  	GeoIPCode string
   487  	IPCIDR    []*netip.Prefix
   488  	Domain    []string
   489  	GeoSite   []*router.DomainMatcher
   490  }
   491  
   492  type Config struct {
   493  	Main, Fallback []NameServer
   494  	Default        []NameServer
   495  	ProxyServer    []NameServer
   496  	Remote         []NameServer
   497  	IPv6           bool
   498  	EnhancedMode   C.DNSMode
   499  	FallbackFilter FallbackFilter
   500  	Pool           *fakeip.Pool
   501  	Hosts          *trie.DomainTrie[netip.Addr]
   502  	Policy         map[string]NameServer
   503  	SearchDomains  []string
   504  }
   505  
   506  func NewResolver(config Config) *Resolver {
   507  	defaultResolver := &Resolver{
   508  		main: transform(config.Default, nil),
   509  		lruCache: cache.New[string, *rMsg](
   510  			cache.WithSize[string, *rMsg](128),
   511  			cache.WithStale[string, *rMsg](true),
   512  		),
   513  	}
   514  
   515  	r := &Resolver{
   516  		ipv6: config.IPv6,
   517  		main: transform(config.Main, defaultResolver),
   518  		lruCache: cache.New[string, *rMsg](
   519  			cache.WithSize[string, *rMsg](10240),
   520  			cache.WithStale[string, *rMsg](true),
   521  		),
   522  		hosts:         config.Hosts,
   523  		searchDomains: config.SearchDomains,
   524  	}
   525  
   526  	if len(config.Fallback) != 0 {
   527  		r.fallback = transform(config.Fallback, defaultResolver)
   528  	}
   529  
   530  	if len(config.ProxyServer) != 0 {
   531  		r.proxyServer = transform(config.ProxyServer, defaultResolver)
   532  	}
   533  
   534  	if len(config.Remote) != 0 {
   535  		remotes := lo.Map(config.Remote, func(item NameServer, _ int) NameServer {
   536  			item.Proxy = "remote-resolver"
   537  			return item
   538  		})
   539  		r.remote = transform(remotes, defaultResolver)
   540  	}
   541  
   542  	if len(config.Policy) != 0 {
   543  		r.policy = trie.New[*Policy]()
   544  		for domain, nameserver := range config.Policy {
   545  			_ = r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver)))
   546  		}
   547  	}
   548  
   549  	var fallbackIPFilters []fallbackIPFilter
   550  	if config.FallbackFilter.GeoIP {
   551  		fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
   552  			code: config.FallbackFilter.GeoIPCode,
   553  		})
   554  	}
   555  	for _, ipnet := range config.FallbackFilter.IPCIDR {
   556  		fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
   557  	}
   558  	r.fallbackIPFilters = fallbackIPFilters
   559  
   560  	var fallbackDomainFilters []fallbackDomainFilter
   561  	if len(config.FallbackFilter.Domain) != 0 {
   562  		fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain))
   563  	}
   564  
   565  	if len(config.FallbackFilter.GeoSite) != 0 {
   566  		fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{
   567  			matchers: config.FallbackFilter.GeoSite,
   568  		})
   569  	}
   570  	r.fallbackDomainFilters = fallbackDomainFilters
   571  
   572  	return r
   573  }