github.com/igoogolx/clash@v1.19.8/dns/resolver.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/igoogolx/clash/common/cache"
    13  	"github.com/igoogolx/clash/component/fakeip"
    14  	"github.com/igoogolx/clash/component/resolver"
    15  	"github.com/igoogolx/clash/component/trie"
    16  	C "github.com/igoogolx/clash/constant"
    17  
    18  	D "github.com/miekg/dns"
    19  	"github.com/samber/lo"
    20  	"golang.org/x/sync/singleflight"
    21  )
    22  
    23  type dnsClient interface {
    24  	Exchange(m *D.Msg) (msg *D.Msg, err error)
    25  	ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
    26  }
    27  
// result carries the outcome of one asynchronous exchange, as produced by
// asyncExchange: the reply message and the error returned alongside it.
type result struct {
	Msg   *D.Msg
	Error error
}
    32  
// Resolver answers DNS questions by forwarding them to configured upstream
// nameservers, with optional per-domain policies, a fallback server group,
// request de-duplication and an LRU answer cache. Build it with NewResolver.
type Resolver struct {
	// ipv6 is copied from Config.IPv6; it is not read by the lookups in this
	// file (NOTE(review): presumably consulted elsewhere — confirm).
	ipv6 bool
	// hosts is copied from Config.Hosts; it is not read by the lookups in this
	// file (NOTE(review): presumably used by callers/middleware — confirm).
	hosts *trie.DomainTrie
	// main are the clients built from Config.Main; the default upstreams.
	main []dnsClient
	// fallback are the clients built from Config.Fallback; nil when no
	// fallback servers are configured.
	fallback []dnsClient
	// fallbackDomainFilters: an IP question whose domain matches one of these
	// is sent only to the fallback servers.
	fallbackDomainFilters []fallbackDomainFilter
	// fallbackIPFilters: if any IP in a main-server answer matches one of
	// these, the fallback answer is used instead (see ipExchange).
	fallbackIPFilters []fallbackIPFilter
	// group collapses concurrent exchanges for the same question string into
	// one upstream request.
	group singleflight.Group
	// lruCache stores answers keyed by the question string.
	lruCache *cache.LruCache
	// policy maps domains to dedicated []dnsClient lists; nil when no policy
	// is configured.
	policy *trie.DomainTrie
	// searchDomains are suffixes that lookupIP tries, in order, when a name
	// yields no addresses.
	searchDomains []string
	// disableCache, when true, makes exchangeWithoutCache skip both the
	// singleflight group and writes to lruCache.
	disableCache bool
}
    46  
    47  // LookupIP request with TypeA and TypeAAAA, priority return TypeA
    48  func (r *Resolver) LookupIP(ctx context.Context, host string) (ip []net.IP, err error) {
    49  	ctx, cancel := context.WithCancel(ctx)
    50  	defer cancel()
    51  
    52  	ch := make(chan []net.IP, 1)
    53  
    54  	go func() {
    55  		defer close(ch)
    56  		ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
    57  		if err != nil {
    58  			return
    59  		}
    60  		ch <- ip
    61  	}()
    62  
    63  	ip, err = r.lookupIP(ctx, host, D.TypeA)
    64  	if err == nil {
    65  		return
    66  	}
    67  
    68  	ip, open := <-ch
    69  	if !open {
    70  		return nil, resolver.ErrIPNotFound
    71  	}
    72  
    73  	return ip, nil
    74  }
    75  
    76  // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
    77  func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
    78  	ips, err := r.LookupIP(context.Background(), host)
    79  	if err != nil {
    80  		return nil, err
    81  	} else if len(ips) == 0 {
    82  		return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
    83  	}
    84  	return ips[rand.Intn(len(ips))], nil
    85  }
    86  
    87  // LookupIPv4 request with TypeA
    88  func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]net.IP, error) {
    89  	return r.lookupIP(ctx, host, D.TypeA)
    90  }
    91  
    92  // ResolveIPv4 request with TypeA
    93  func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) {
    94  	ips, err := r.lookupIP(context.Background(), host, D.TypeA)
    95  	if err != nil {
    96  		return nil, err
    97  	} else if len(ips) == 0 {
    98  		return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
    99  	}
   100  	return ips[rand.Intn(len(ips))], nil
   101  }
   102  
   103  // LookupIPv6 request with TypeAAAA
   104  func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]net.IP, error) {
   105  	return r.lookupIP(ctx, host, D.TypeAAAA)
   106  }
   107  
   108  // ResolveIPv6 request with TypeAAAA
   109  func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
   110  	ips, err := r.lookupIP(context.Background(), host, D.TypeAAAA)
   111  	if err != nil {
   112  		return nil, err
   113  	} else if len(ips) == 0 {
   114  		return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, host)
   115  	}
   116  	return ips[rand.Intn(len(ips))], nil
   117  }
   118  
   119  func (r *Resolver) shouldIPFallback(ip net.IP) bool {
   120  	for _, filter := range r.fallbackIPFilters {
   121  		if filter.Match(ip) {
   122  			return true
   123  		}
   124  	}
   125  	return false
   126  }
   127  
   128  // Exchange a batch of dns request, and it use cache
   129  func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
   130  	return r.ExchangeContext(context.Background(), m)
   131  }
   132  
   133  // ExchangeContext a batch of dns request with context.Context, and it use cache
   134  func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   135  	if len(m.Question) == 0 {
   136  		return nil, errors.New("should have one question at least")
   137  	}
   138  
   139  	q := m.Question[0]
   140  	cache, expireTime, hit := r.lruCache.GetWithExpire(q.String())
   141  	if hit {
   142  		now := time.Now()
   143  		msg = cache.(*D.Msg).Copy()
   144  		if expireTime.Before(now) {
   145  			setMsgTTL(msg, uint32(1)) // Continue fetch
   146  			go func() {
   147  				ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
   148  				r.exchangeWithoutCache(ctx, m)
   149  				cancel()
   150  			}()
   151  		} else {
   152  			// updating TTL by subtracting common delta time from each DNS record
   153  			updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
   154  		}
   155  		return
   156  	}
   157  	return r.exchangeWithoutCache(ctx, m)
   158  }
   159  
   160  func (r *Resolver) exchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   161  	q := m.Question[0]
   162  	isIPReq := isIPRequest(q)
   163  	if isIPReq {
   164  		return r.ipExchange(ctx, m)
   165  	}
   166  
   167  	if matched := r.matchPolicy(m); len(matched) != 0 {
   168  		return r.batchExchange(ctx, matched, m)
   169  	}
   170  	return r.batchExchange(ctx, r.main, m)
   171  }
   172  
   173  // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache
   174  func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   175  	if r.disableCache {
   176  		msg, err = r.exchange(ctx, m)
   177  	} else {
   178  		q := m.Question[0]
   179  		ret, err, shared := r.group.Do(q.String(), func() (result any, err error) {
   180  			defer func() {
   181  				if err != nil {
   182  					return
   183  				}
   184  
   185  				msg := result.(*D.Msg)
   186  				// OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files.
   187  				msg.Extra = lo.Filter(msg.Extra, func(rr D.RR, index int) bool {
   188  					return rr.Header().Rrtype != D.TypeOPT
   189  				})
   190  				putMsgToCache(r.lruCache, q.String(), q, msg)
   191  			}()
   192  			return r.exchange(ctx, m)
   193  		})
   194  		if err == nil {
   195  			msg = ret.(*D.Msg)
   196  			if shared {
   197  				msg = msg.Copy()
   198  			}
   199  		}
   200  	}
   201  
   202  	return
   203  }
   204  
   205  func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
   206  	ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDNSTimeout)
   207  	defer cancel()
   208  
   209  	return batchExchange(ctx, clients, m)
   210  }
   211  
   212  func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
   213  	if r.policy == nil {
   214  		return nil
   215  	}
   216  
   217  	domain := r.msgToDomain(m)
   218  	if domain == "" {
   219  		return nil
   220  	}
   221  
   222  	record := r.policy.Search(domain)
   223  	if record == nil {
   224  		return nil
   225  	}
   226  
   227  	return record.Data.([]dnsClient)
   228  }
   229  
   230  func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
   231  	if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
   232  		return false
   233  	}
   234  
   235  	domain := r.msgToDomain(m)
   236  
   237  	if domain == "" {
   238  		return false
   239  	}
   240  
   241  	for _, df := range r.fallbackDomainFilters {
   242  		if df.Match(domain) {
   243  			return true
   244  		}
   245  	}
   246  
   247  	return false
   248  }
   249  
   250  func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   251  	if matched := r.matchPolicy(m); len(matched) != 0 {
   252  		res := <-r.asyncExchange(ctx, matched, m)
   253  		return res.Msg, res.Error
   254  	}
   255  
   256  	onlyFallback := r.shouldOnlyQueryFallback(m)
   257  
   258  	if onlyFallback {
   259  		res := <-r.asyncExchange(ctx, r.fallback, m)
   260  		return res.Msg, res.Error
   261  	}
   262  
   263  	msgCh := r.asyncExchange(ctx, r.main, m)
   264  
   265  	if r.fallback == nil { // directly return if no fallback servers are available
   266  		res := <-msgCh
   267  		msg, err = res.Msg, res.Error
   268  		return
   269  	}
   270  
   271  	fallbackMsg := r.asyncExchange(ctx, r.fallback, m)
   272  	res := <-msgCh
   273  	if res.Error == nil {
   274  		if ips := msgToIP(res.Msg); len(ips) != 0 {
   275  			shouldNotFallback := lo.EveryBy(ips, func(ip net.IP) bool {
   276  				return !r.shouldIPFallback(ip)
   277  			})
   278  			if shouldNotFallback {
   279  				msg = res.Msg // no need to wait for fallback result
   280  				err = res.Error
   281  				return msg, err
   282  			}
   283  		}
   284  	}
   285  
   286  	res = <-fallbackMsg
   287  	msg, err = res.Msg, res.Error
   288  	return
   289  }
   290  
   291  func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) ([]net.IP, error) {
   292  	ip := net.ParseIP(host)
   293  	if ip != nil {
   294  		ip4 := ip.To4()
   295  		isIPv4 := ip4 != nil
   296  		if dnsType == D.TypeAAAA && !isIPv4 {
   297  			return []net.IP{ip}, nil
   298  		} else if dnsType == D.TypeA && isIPv4 {
   299  			return []net.IP{ip4}, nil
   300  		} else {
   301  			return nil, resolver.ErrIPVersion
   302  		}
   303  	}
   304  
   305  	query := &D.Msg{}
   306  	query.SetQuestion(D.Fqdn(host), dnsType)
   307  
   308  	msg, err := r.ExchangeContext(ctx, query)
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  
   313  	ips := msgToIP(msg)
   314  	if len(ips) != 0 {
   315  		return ips, nil
   316  	} else if len(r.searchDomains) == 0 {
   317  		return nil, resolver.ErrIPNotFound
   318  	}
   319  
   320  	// query provided search domains serially
   321  	for _, domain := range r.searchDomains {
   322  		q := &D.Msg{}
   323  		q.SetQuestion(D.Fqdn(fmt.Sprintf("%s.%s", host, domain)), dnsType)
   324  		msg, err := r.ExchangeContext(ctx, q)
   325  		if err != nil {
   326  			return nil, err
   327  		}
   328  		ips := msgToIP(msg)
   329  		if len(ips) != 0 {
   330  			return ips, nil
   331  		}
   332  	}
   333  
   334  	return nil, resolver.ErrIPNotFound
   335  }
   336  
   337  func (r *Resolver) msgToDomain(msg *D.Msg) string {
   338  	if len(msg.Question) > 0 {
   339  		return strings.TrimRight(msg.Question[0].Name, ".")
   340  	}
   341  
   342  	return ""
   343  }
   344  
   345  func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result {
   346  	ch := make(chan *result, 1)
   347  	go func() {
   348  		res, err := r.batchExchange(ctx, client, msg)
   349  		ch <- &result{Msg: res, Error: err}
   350  	}()
   351  	return ch
   352  }
   353  
// NameServer describes one upstream DNS server. NewResolver converts lists
// of these into clients via transform, which is not shown here.
type NameServer struct {
	// Net is the network/transport name; interpreted by transform
	// (NOTE(review): accepted values not visible here — confirm).
	Net string
	// Addr is the server address (NOTE(review): presumably host:port — confirm
	// against transform).
	Addr string
	// Interface optionally names an outbound interface (NOTE(review): inferred
	// from the field name; semantics live in transform — confirm).
	Interface string
}
   359  
// FallbackFilter configures when answers from the main nameservers are
// distrusted in favour of the fallback nameservers.
type FallbackFilter struct {
	// GeoIP adds a GeoIP-based IP filter using GeoIPCode.
	GeoIP bool
	// GeoIPCode is the code passed to the GeoIP filter when GeoIP is set.
	GeoIPCode string
	// IPCIDR lists networks; a main-server answer containing an IP matched by
	// any IP filter is replaced by the fallback answer.
	IPCIDR []*net.IPNet
	// Domain lists domains whose IP questions are sent only to the fallback
	// servers.
	Domain []string
}
   366  
// Config holds the settings NewResolver uses to build a Resolver.
type Config struct {
	// Main are the default upstreams; Fallback are optional secondary
	// upstreams used according to FallbackFilter.
	Main, Fallback []NameServer
	// Default is not read by NewResolver (NOTE(review): presumably bootstrap
	// servers used elsewhere — confirm against callers).
	Default []NameServer
	// IPv6 is copied into the Resolver.
	IPv6 bool
	// EnhancedMode is not read by NewResolver (NOTE(review): used elsewhere —
	// confirm).
	EnhancedMode C.DNSMode
	// FallbackFilter decides when the fallback servers are used.
	FallbackFilter FallbackFilter
	// Pool is not read by NewResolver (NOTE(review): presumably the fake-IP
	// pool used by other components — confirm).
	Pool *fakeip.Pool
	// Hosts is stored on the Resolver as-is.
	Hosts *trie.DomainTrie
	// Policy maps a domain to a dedicated nameserver that takes precedence
	// over Main/Fallback for matching questions.
	Policy map[string]NameServer
	// SearchDomains are suffixes tried when a lookup yields no addresses.
	SearchDomains []string
	// DisableCache turns off request de-duplication and answer caching.
	DisableCache bool
	// GetDialer is passed to transform for every configured nameserver
	// (NOTE(review): presumably supplies the proxy used to reach upstreams —
	// confirm).
	GetDialer func() (C.Proxy, error)
}
   380  
   381  func NewResolver(config Config) *Resolver {
   382  
   383  	r := &Resolver{
   384  		ipv6:          config.IPv6,
   385  		main:          transform(config.Main, config.GetDialer),
   386  		lruCache:      cache.New(cache.WithSize(4096), cache.WithStale(true)),
   387  		hosts:         config.Hosts,
   388  		searchDomains: config.SearchDomains,
   389  		disableCache:  config.DisableCache,
   390  	}
   391  
   392  	if len(config.Fallback) != 0 {
   393  		r.fallback = transform(config.Fallback, config.GetDialer)
   394  	}
   395  
   396  	if len(config.Policy) != 0 {
   397  		r.policy = trie.New()
   398  		for domain, nameserver := range config.Policy {
   399  			r.policy.Insert(domain, transform([]NameServer{nameserver}, config.GetDialer))
   400  		}
   401  	}
   402  
   403  	fallbackIPFilters := []fallbackIPFilter{}
   404  	if config.FallbackFilter.GeoIP {
   405  		fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
   406  			code: config.FallbackFilter.GeoIPCode,
   407  		})
   408  	}
   409  	for _, ipnet := range config.FallbackFilter.IPCIDR {
   410  		fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
   411  	}
   412  	r.fallbackIPFilters = fallbackIPFilters
   413  
   414  	if len(config.FallbackFilter.Domain) != 0 {
   415  		fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)}
   416  		r.fallbackDomainFilters = fallbackDomainFilters
   417  	}
   418  
   419  	return r
   420  }