github.com/metacubex/mihomo@v1.18.5/dns/resolver.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/netip"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/metacubex/mihomo/common/arc"
    11  	"github.com/metacubex/mihomo/common/lru"
    12  	"github.com/metacubex/mihomo/component/fakeip"
    13  	"github.com/metacubex/mihomo/component/geodata/router"
    14  	"github.com/metacubex/mihomo/component/resolver"
    15  	"github.com/metacubex/mihomo/component/trie"
    16  	C "github.com/metacubex/mihomo/constant"
    17  	"github.com/metacubex/mihomo/constant/provider"
    18  	"github.com/metacubex/mihomo/log"
    19  
    20  	D "github.com/miekg/dns"
    21  	"github.com/samber/lo"
    22  	orderedmap "github.com/wk8/go-ordered-map/v2"
    23  	"golang.org/x/exp/maps"
    24  	"golang.org/x/sync/singleflight"
    25  )
    26  
// dnsClient is an upstream DNS server the resolver can query. NewResolver
// builds one per configured NameServer via transform.
type dnsClient interface {
	// ExchangeContext sends m to the upstream server and returns its reply.
	ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
	// Address returns the upstream server's address; the exact format is
	// implementation-defined.
	Address() string
}
    31  
// dnsCache stores DNS responses with an expiry time. Resolver keys entries by
// the question's String() form. NewResolver backs it with an LRU or an ARC
// cache.
type dnsCache interface {
	// GetWithExpire returns the cached message for key, its expiry time, and
	// whether an entry was found. ExchangeContext treats a found entry whose
	// expiry has already passed as stale (served with TTL 1 and refreshed)
	// rather than as missing.
	GetWithExpire(key string) (*D.Msg, time.Time, bool)
	// SetWithExpire stores value under key with the given expiry time.
	SetWithExpire(key string, value *D.Msg, expire time.Time)
}
    36  
// result is the outcome of one asynchronous exchange, delivered over the
// channel returned by asyncExchange.
type result struct {
	Msg   *D.Msg
	Error error
}
    41  
// Resolver answers DNS queries from a response cache and from upstream
// nameserver groups (policy-matched, main and fallback). Instances are built
// by NewResolver or NewProxyServerHostResolver.
type Resolver struct {
	// ipv6 is copied from Config.IPv6; it is not read in this part of the file.
	ipv6                  bool
	// ipv6Timeout bounds how long LookupIP waits for the AAAA answer after the
	// A lookup returns; a non-positive value means 100ms.
	ipv6Timeout           time.Duration
	// hosts is the static hosts table from Config.Hosts.
	hosts                 *trie.DomainTrie[resolver.HostValue]
	// main holds the default upstream servers.
	main                  []dnsClient
	// fallback servers are used by ipExchange for domains matched by
	// fallbackDomainFilters, or when a main answer contains an IP matched by
	// fallbackIPFilters.
	fallback              []dnsClient
	fallbackDomainFilters []fallbackDomainFilter
	fallbackIPFilters     []fallbackIPFilter
	// group de-duplicates concurrent upstream queries for the same question.
	group                 singleflight.Group
	// cache holds responses keyed by the question's String() form.
	cache                 dnsCache
	// policy lists per-domain nameserver overrides, checked in order by
	// matchPolicy; the first match wins.
	policy                []dnsPolicy
	// proxyServer becomes the main server list in NewProxyServerHostResolver.
	proxyServer           []dnsClient
}
    55  
    56  func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) {
    57  	ch := make(chan []netip.Addr, 1)
    58  	go func() {
    59  		defer close(ch)
    60  		ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
    61  		if err != nil {
    62  			return
    63  		}
    64  		ch <- ip
    65  	}()
    66  
    67  	ips, err = r.lookupIP(ctx, host, D.TypeA)
    68  	if err == nil {
    69  		return
    70  	}
    71  
    72  	ip, open := <-ch
    73  	if !open {
    74  		return nil, resolver.ErrIPNotFound
    75  	}
    76  
    77  	return ip, nil
    78  }
    79  
    80  func (r *Resolver) LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) {
    81  	ch := make(chan []netip.Addr, 1)
    82  	go func() {
    83  		defer close(ch)
    84  		ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
    85  		if err != nil {
    86  			return
    87  		}
    88  
    89  		ch <- ip
    90  	}()
    91  
    92  	ips, err = r.lookupIP(ctx, host, D.TypeA)
    93  	var waitIPv6 *time.Timer
    94  	if r != nil && r.ipv6Timeout > 0 {
    95  		waitIPv6 = time.NewTimer(r.ipv6Timeout)
    96  	} else {
    97  		waitIPv6 = time.NewTimer(100 * time.Millisecond)
    98  	}
    99  	defer waitIPv6.Stop()
   100  	select {
   101  	case ipv6s, open := <-ch:
   102  		if !open && err != nil {
   103  			return nil, resolver.ErrIPNotFound
   104  		}
   105  		ips = append(ips, ipv6s...)
   106  	case <-waitIPv6.C:
   107  		// wait ipv6 result
   108  	}
   109  
   110  	return ips, nil
   111  }
   112  
   113  // LookupIPv4 request with TypeA
   114  func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) {
   115  	return r.lookupIP(ctx, host, D.TypeA)
   116  }
   117  
   118  // LookupIPv6 request with TypeAAAA
   119  func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) {
   120  	return r.lookupIP(ctx, host, D.TypeAAAA)
   121  }
   122  
   123  func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
   124  	for _, filter := range r.fallbackIPFilters {
   125  		if filter.Match(ip) {
   126  			return true
   127  		}
   128  	}
   129  	return false
   130  }
   131  
   132  // ExchangeContext a batch of dns request with context.Context, and it use cache
   133  func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   134  	if len(m.Question) == 0 {
   135  		return nil, errors.New("should have one question at least")
   136  	}
   137  	continueFetch := false
   138  	defer func() {
   139  		if continueFetch || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
   140  			go func() {
   141  				ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
   142  				defer cancel()
   143  				_, _ = r.exchangeWithoutCache(ctx, m) // ignore result, just for putMsgToCache
   144  			}()
   145  		}
   146  	}()
   147  
   148  	q := m.Question[0]
   149  	cacheM, expireTime, hit := r.cache.GetWithExpire(q.String())
   150  	if hit {
   151  		log.Debugln("[DNS] cache hit for %s, expire at %s", q.Name, expireTime.Format("2006-01-02 15:04:05"))
   152  		now := time.Now()
   153  		msg = cacheM.Copy()
   154  		if expireTime.Before(now) {
   155  			setMsgTTL(msg, uint32(1)) // Continue fetch
   156  			continueFetch = true
   157  		} else {
   158  			// updating TTL by subtracting common delta time from each DNS record
   159  			updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
   160  		}
   161  		return
   162  	}
   163  	return r.exchangeWithoutCache(ctx, m)
   164  }
   165  
// exchangeWithoutCache sends m upstream without reading the cache. Concurrent
// calls for the same question are de-duplicated through r.group (singleflight
// keyed by the question's String() form). A successful answer is written to
// the cache when the request is an IP request, or when batchExchange reports
// it as cacheable.
//
// The upstream work runs under its own DefaultDNSTimeout context rather than
// ctx, so one caller giving up does not cancel a query shared with others. If
// ctx ends first, ctx.Err() is returned and a background goroutine waits for
// the in-flight query. A failed, unshared attempt triggers one
// fire-and-forget retry through the same singleflight key.
func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
	q := m.Question[0]

	// retryNum counts failed runs of fn. It is only touched inside fn, whose
	// runs for this key do not overlap because singleflight allows one
	// in-flight call per key.
	// NOTE(review): the channel returned by the retry DoChan calls below is
	// discarded, so nothing re-checks a failed retry. When the checked result
	// is unshared, ret.(int) is 0, so retryMax is effectively never reached —
	// confirm whether more retries were intended.
	retryNum := 0
	retryMax := 3
	fn := func() (result any, err error) {
		ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) // reset timeout in singleflight
		defer cancel()
		// cache records whether a successful answer should be stored.
		cache := false

		defer func() {
			if err != nil {
				// On failure the result value is replaced by the attempt
				// number; the callers below compare it with retryMax.
				result = retryNum
				retryNum++
				return
			}

			msg := result.(*D.Msg)

			if cache {
				// OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files.
				msg.Extra = lo.Filter(msg.Extra, func(rr D.RR, index int) bool {
					return rr.Header().Rrtype != D.TypeOPT
				})
				putMsgToCache(r.cache, q.String(), q, msg)
			}
		}()

		// IP (A/AAAA) requests go through the policy/main/fallback logic and
		// are always cached on success.
		isIPReq := isIPRequest(q)
		if isIPReq {
			cache = true
			return r.ipExchange(ctx, m)
		}

		// Other record types: policy-matched servers if any, else main.
		if matched := r.matchPolicy(m); len(matched) != 0 {
			result, cache, err = batchExchange(ctx, matched, m)
			return
		}
		result, cache, err = batchExchange(ctx, r.main, m)
		return
	}

	ch := r.group.DoChan(q.String(), fn)

	var result singleflight.Result

	select {
	case result = <-ch:
		break
	case <-ctx.Done():
		select {
		case result = <-ch: // maybe ctxDone and chFinish in same time, get DoChan's result as much as possible
			break
		default:
			go func() { // start a retrying monitor in background
				result := <-ch
				ret, err, shared := result.Val, result.Err, result.Shared
				if err != nil && !shared && ret.(int) < retryMax { // retry
					r.group.DoChan(q.String(), fn)
				}
			}()
			return nil, ctx.Err()
		}
	}

	// On error, Val holds the attempt number set by fn's deferred function.
	// `!shared` is checked first, so the int assertion only runs on a result
	// produced for this caller alone.
	ret, err, shared := result.Val, result.Err, result.Shared
	if err != nil && !shared && ret.(int) < retryMax { // retry
		r.group.DoChan(q.String(), fn)
	}

	if err == nil {
		msg = ret.(*D.Msg)
		// A shared result is seen by several callers; give each its own copy.
		// NOTE(review): an unshared msg is the same pointer passed to
		// putMsgToCache. If that helper does not copy, the caller gets the
		// cached object itself — confirm.
		if shared {
			msg = msg.Copy()
		}
	}

	return
}
   246  
   247  func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
   248  	if r.policy == nil {
   249  		return nil
   250  	}
   251  
   252  	domain := msgToDomain(m)
   253  	if domain == "" {
   254  		return nil
   255  	}
   256  
   257  	for _, policy := range r.policy {
   258  		if dnsClients := policy.Match(domain); len(dnsClients) > 0 {
   259  			return dnsClients
   260  		}
   261  	}
   262  	return nil
   263  }
   264  
   265  func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
   266  	if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
   267  		return false
   268  	}
   269  
   270  	domain := msgToDomain(m)
   271  
   272  	if domain == "" {
   273  		return false
   274  	}
   275  
   276  	for _, df := range r.fallbackDomainFilters {
   277  		if df.Match(domain) {
   278  			return true
   279  		}
   280  	}
   281  
   282  	return false
   283  }
   284  
   285  func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   286  	if matched := r.matchPolicy(m); len(matched) != 0 {
   287  		res := <-r.asyncExchange(ctx, matched, m)
   288  		return res.Msg, res.Error
   289  	}
   290  
   291  	onlyFallback := r.shouldOnlyQueryFallback(m)
   292  
   293  	if onlyFallback {
   294  		res := <-r.asyncExchange(ctx, r.fallback, m)
   295  		return res.Msg, res.Error
   296  	}
   297  
   298  	msgCh := r.asyncExchange(ctx, r.main, m)
   299  
   300  	if r.fallback == nil || len(r.fallback) == 0 { // directly return if no fallback servers are available
   301  		res := <-msgCh
   302  		msg, err = res.Msg, res.Error
   303  		return
   304  	}
   305  
   306  	res := <-msgCh
   307  	if res.Error == nil {
   308  		if ips := msgToIP(res.Msg); len(ips) != 0 {
   309  			shouldNotFallback := lo.EveryBy(ips, func(ip netip.Addr) bool {
   310  				return !r.shouldIPFallback(ip)
   311  			})
   312  			if shouldNotFallback {
   313  				msg, err = res.Msg, res.Error // no need to wait for fallback result
   314  				return
   315  			}
   316  		}
   317  	}
   318  
   319  	res = <-r.asyncExchange(ctx, r.fallback, m)
   320  	msg, err = res.Msg, res.Error
   321  	return
   322  }
   323  
   324  func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) (ips []netip.Addr, err error) {
   325  	ip, err := netip.ParseAddr(host)
   326  	if err == nil {
   327  		isIPv4 := ip.Is4() || ip.Is4In6()
   328  		if dnsType == D.TypeAAAA && !isIPv4 {
   329  			return []netip.Addr{ip}, nil
   330  		} else if dnsType == D.TypeA && isIPv4 {
   331  			return []netip.Addr{ip}, nil
   332  		} else {
   333  			return []netip.Addr{}, resolver.ErrIPVersion
   334  		}
   335  	}
   336  
   337  	query := &D.Msg{}
   338  	query.SetQuestion(D.Fqdn(host), dnsType)
   339  
   340  	msg, err := r.ExchangeContext(ctx, query)
   341  	if err != nil {
   342  		return []netip.Addr{}, err
   343  	}
   344  
   345  	ips = msgToIP(msg)
   346  	ipLength := len(ips)
   347  	if ipLength == 0 {
   348  		return []netip.Addr{}, resolver.ErrIPNotFound
   349  	}
   350  
   351  	return
   352  }
   353  
   354  func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result {
   355  	ch := make(chan *result, 1)
   356  	go func() {
   357  		res, _, err := batchExchange(ctx, client, msg)
   358  		ch <- &result{Msg: res, Error: err}
   359  	}()
   360  	return ch
   361  }
   362  
   363  // Invalid return this resolver can or can't be used
   364  func (r *Resolver) Invalid() bool {
   365  	if r == nil {
   366  		return false
   367  	}
   368  	return len(r.main) > 0
   369  }
   370  
// NameServer describes one configured upstream DNS server. NewResolver turns
// each NameServer into a dnsClient via transform, and entries that compare
// Equal share a single client.
//
// NOTE(review): field meanings are defined by ParseNameServer
// (config/config.go) and transform, which are not visible here. Going by the
// names: Net = transport, Addr = server address, Interface = outbound
// interface, ProxyAdapter/ProxyName = proxy to dial through, Params = extra
// options, PreferH3 = prefer HTTP/3 — verify before relying on these.
type NameServer struct {
	Net          string
	Addr         string
	Interface    string
	ProxyAdapter C.ProxyAdapter
	ProxyName    string
	Params       map[string]string
	PreferH3     bool
}
   380  
   381  func (ns NameServer) Equal(ns2 NameServer) bool {
   382  	defer func() {
   383  		// C.ProxyAdapter compare maybe panic, just ignore
   384  		recover()
   385  	}()
   386  	if ns.Net == ns2.Net &&
   387  		ns.Addr == ns2.Addr &&
   388  		ns.Interface == ns2.Interface &&
   389  		ns.ProxyAdapter == ns2.ProxyAdapter &&
   390  		ns.ProxyName == ns2.ProxyName &&
   391  		maps.Equal(ns.Params, ns2.Params) &&
   392  		ns.PreferH3 == ns2.PreferH3 {
   393  		return true
   394  	}
   395  	return false
   396  }
   397  
// FallbackFilter configures when the fallback nameservers are used.
//
// NewResolver turns GeoIP/GeoIPCode and IPCIDR into fallback IP filters. A
// main-server answer containing a matching IP is re-queried on the fallback
// servers. Domain and GeoSite become fallback domain filters: matching
// domains are sent only to the fallback servers.
type FallbackFilter struct {
	// GeoIP enables a geoipFilter built from GeoIPCode (its match semantics
	// live in geoipFilter).
	GeoIP     bool
	GeoIPCode string
	// IPCIDR prefixes each become an ipnetFilter.
	IPCIDR    []netip.Prefix
	// Domain patterns are passed to NewDomainFilter.
	Domain    []string
	// GeoSite matchers are wrapped in a geoSiteFilter.
	GeoSite   []router.DomainMatcher
}
   405  
// Config is the input to NewResolver.
type Config struct {
	// Main nameservers answer ordinary queries. Fallback nameservers are
	// consulted according to FallbackFilter.
	Main, Fallback []NameServer
	// Default nameservers form the bootstrap resolver that NewResolver
	// passes to transform when setting up every other nameserver.
	Default        []NameServer
	// ProxyServer nameservers become the main servers of the resolver
	// returned by NewProxyServerHostResolver.
	ProxyServer    []NameServer
	// IPv6 is copied into the resolver.
	IPv6           bool
	// IPv6Timeout is in milliseconds. It is how long LookupIP waits for the
	// AAAA answer after the A lookup; 0 selects 100ms.
	IPv6Timeout    uint
	// EnhancedMode and Pool are not read by NewResolver; presumably they are
	// consumed elsewhere in the package — verify.
	EnhancedMode   C.DNSMode
	FallbackFilter FallbackFilter
	Pool           *fakeip.Pool
	// Hosts is the static hosts table stored on the resolver.
	Hosts          *trie.DomainTrie[resolver.HostValue]
	// Policy maps a key to the nameservers for matching domains and is
	// evaluated in insertion order. A key is "rule-set:<provider>",
	// "geosite:<code>" ("geosite:!<code>" inverts the match), or a domain
	// pattern.
	Policy         *orderedmap.OrderedMap[string, []NameServer]
	// RuleProviders resolves the provider names used in "rule-set:" keys.
	RuleProviders  map[string]provider.RuleProvider
	// CacheAlgorithm selects the response cache implementation ("lru" or
	// ARC); see NewResolver for how an empty value is handled.
	CacheAlgorithm string
}
   420  
   421  func NewResolver(config Config) *Resolver {
   422  	var cache dnsCache
   423  	if config.CacheAlgorithm == "lru" {
   424  		cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true))
   425  	} else {
   426  		cache = arc.New(arc.WithSize[string, *D.Msg](4096))
   427  	}
   428  	defaultResolver := &Resolver{
   429  		main:        transform(config.Default, nil),
   430  		cache:       cache,
   431  		ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
   432  	}
   433  
   434  	var nameServerCache []struct {
   435  		NameServer
   436  		dnsClient
   437  	}
   438  	cacheTransform := func(nameserver []NameServer) (result []dnsClient) {
   439  	LOOP:
   440  		for _, ns := range nameserver {
   441  			for _, nsc := range nameServerCache {
   442  				if nsc.NameServer.Equal(ns) {
   443  					result = append(result, nsc.dnsClient)
   444  					continue LOOP
   445  				}
   446  			}
   447  			// not in cache
   448  			dc := transform([]NameServer{ns}, defaultResolver)
   449  			if len(dc) > 0 {
   450  				dc := dc[0]
   451  				nameServerCache = append(nameServerCache, struct {
   452  					NameServer
   453  					dnsClient
   454  				}{NameServer: ns, dnsClient: dc})
   455  				result = append(result, dc)
   456  			}
   457  		}
   458  		return
   459  	}
   460  
   461  	if config.CacheAlgorithm == "" || config.CacheAlgorithm == "lru" {
   462  		cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true))
   463  	} else {
   464  		cache = arc.New(arc.WithSize[string, *D.Msg](4096))
   465  	}
   466  	r := &Resolver{
   467  		ipv6:        config.IPv6,
   468  		main:        cacheTransform(config.Main),
   469  		cache:       cache,
   470  		hosts:       config.Hosts,
   471  		ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
   472  	}
   473  
   474  	if len(config.Fallback) != 0 {
   475  		r.fallback = cacheTransform(config.Fallback)
   476  	}
   477  
   478  	if len(config.ProxyServer) != 0 {
   479  		r.proxyServer = cacheTransform(config.ProxyServer)
   480  	}
   481  
   482  	if config.Policy.Len() != 0 {
   483  		r.policy = make([]dnsPolicy, 0)
   484  
   485  		var triePolicy *trie.DomainTrie[[]dnsClient]
   486  		insertPolicy := func(policy dnsPolicy) {
   487  			if triePolicy != nil {
   488  				triePolicy.Optimize()
   489  				r.policy = append(r.policy, domainTriePolicy{triePolicy})
   490  				triePolicy = nil
   491  			}
   492  			if policy != nil {
   493  				r.policy = append(r.policy, policy)
   494  			}
   495  		}
   496  
   497  		for pair := config.Policy.Oldest(); pair != nil; pair = pair.Next() {
   498  			domain, nameserver := pair.Key, pair.Value
   499  
   500  			if temp := strings.Split(domain, ":"); len(temp) == 2 {
   501  				prefix := temp[0]
   502  				key := temp[1]
   503  				switch prefix {
   504  				case "rule-set":
   505  					if p, ok := config.RuleProviders[key]; ok {
   506  						log.Debugln("Adding rule-set policy: %s ", key)
   507  						insertPolicy(domainSetPolicy{
   508  							domainSetProvider: p,
   509  							dnsClients:        cacheTransform(nameserver),
   510  						})
   511  						continue
   512  					} else {
   513  						log.Warnln("Can't found ruleset policy: %s", key)
   514  					}
   515  				case "geosite":
   516  					inverse := false
   517  					if strings.HasPrefix(key, "!") {
   518  						inverse = true
   519  						key = key[1:]
   520  					}
   521  					log.Debugln("Adding geosite policy: %s inversed %t", key, inverse)
   522  					matcher, err := NewGeoSite(key)
   523  					if err != nil {
   524  						log.Warnln("adding geosite policy %s error: %s", key, err)
   525  						continue
   526  					}
   527  					insertPolicy(geositePolicy{
   528  						matcher:    matcher,
   529  						inverse:    inverse,
   530  						dnsClients: cacheTransform(nameserver),
   531  					})
   532  					continue // skip triePolicy new
   533  				}
   534  			}
   535  			if triePolicy == nil {
   536  				triePolicy = trie.New[[]dnsClient]()
   537  			}
   538  			_ = triePolicy.Insert(domain, cacheTransform(nameserver))
   539  		}
   540  		insertPolicy(nil)
   541  	}
   542  
   543  	fallbackIPFilters := []fallbackIPFilter{}
   544  	if config.FallbackFilter.GeoIP {
   545  		fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
   546  			code: config.FallbackFilter.GeoIPCode,
   547  		})
   548  	}
   549  	for _, ipnet := range config.FallbackFilter.IPCIDR {
   550  		fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
   551  	}
   552  	r.fallbackIPFilters = fallbackIPFilters
   553  
   554  	fallbackDomainFilters := []fallbackDomainFilter{}
   555  	if len(config.FallbackFilter.Domain) != 0 {
   556  		fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain))
   557  	}
   558  
   559  	if len(config.FallbackFilter.GeoSite) != 0 {
   560  		fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{
   561  			matchers: config.FallbackFilter.GeoSite,
   562  		})
   563  	}
   564  	r.fallbackDomainFilters = fallbackDomainFilters
   565  
   566  	return r
   567  }
   568  
   569  func NewProxyServerHostResolver(old *Resolver) *Resolver {
   570  	r := &Resolver{
   571  		ipv6:        old.ipv6,
   572  		main:        old.proxyServer,
   573  		cache:       old.cache,
   574  		hosts:       old.hosts,
   575  		ipv6Timeout: old.ipv6Timeout,
   576  	}
   577  	return r
   578  }
   579  
// ParseNameServer converts nameserver strings into NameServer values. It is a
// hook that stays nil in this package until config/config.go assigns the
// implementation.
var ParseNameServer func(servers []string) ([]NameServer, error) // define in config/config.go