github.com/chwjbn/xclash@v0.2.0/dns/resolver.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/chwjbn/xclash/common/cache"
    13  	"github.com/chwjbn/xclash/common/picker"
    14  	"github.com/chwjbn/xclash/component/fakeip"
    15  	"github.com/chwjbn/xclash/component/resolver"
    16  	"github.com/chwjbn/xclash/component/trie"
    17  	C "github.com/chwjbn/xclash/constant"
    18  
    19  	D "github.com/miekg/dns"
    20  	"golang.org/x/sync/singleflight"
    21  )
    22  
    23  type dnsClient interface {
    24  	Exchange(m *D.Msg) (msg *D.Msg, err error)
    25  	ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
    26  }
    27  
    28  type result struct {
    29  	Msg   *D.Msg
    30  	Error error
    31  }
    32  
    33  type Resolver struct {
    34  	ipv6                  bool
    35  	hosts                 *trie.DomainTrie
    36  	main                  []dnsClient
    37  	fallback              []dnsClient
    38  	fallbackDomainFilters []fallbackDomainFilter
    39  	fallbackIPFilters     []fallbackIPFilter
    40  	group                 singleflight.Group
    41  	lruCache              *cache.LruCache
    42  	policy                *trie.DomainTrie
    43  }
    44  
    45  // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
    46  func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
    47  	ch := make(chan net.IP, 1)
    48  	go func() {
    49  		defer close(ch)
    50  		ip, err := r.resolveIP(host, D.TypeAAAA)
    51  		if err != nil {
    52  			return
    53  		}
    54  		ch <- ip
    55  	}()
    56  
    57  	ip, err = r.resolveIP(host, D.TypeA)
    58  	if err == nil {
    59  		return
    60  	}
    61  
    62  	ip, open := <-ch
    63  	if !open {
    64  		return nil, resolver.ErrIPNotFound
    65  	}
    66  
    67  	return ip, nil
    68  }
    69  
    70  // ResolveIPv4 request with TypeA
    71  func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) {
    72  	return r.resolveIP(host, D.TypeA)
    73  }
    74  
    75  // ResolveIPv6 request with TypeAAAA
    76  func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
    77  	return r.resolveIP(host, D.TypeAAAA)
    78  }
    79  
    80  func (r *Resolver) shouldIPFallback(ip net.IP) bool {
    81  	for _, filter := range r.fallbackIPFilters {
    82  		if filter.Match(ip) {
    83  			return true
    84  		}
    85  	}
    86  	return false
    87  }
    88  
    89  // Exchange a batch of dns request, and it use cache
    90  func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
    91  	return r.ExchangeContext(context.Background(), m)
    92  }
    93  
    94  // ExchangeContext a batch of dns request with context.Context, and it use cache
    95  func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
    96  	if len(m.Question) == 0 {
    97  		return nil, errors.New("should have one question at least")
    98  	}
    99  
   100  	q := m.Question[0]
   101  	cache, expireTime, hit := r.lruCache.GetWithExpire(q.String())
   102  	if hit {
   103  		now := time.Now()
   104  		msg = cache.(*D.Msg).Copy()
   105  		if expireTime.Before(now) {
   106  			setMsgTTL(msg, uint32(1)) // Continue fetch
   107  			go r.exchangeWithoutCache(ctx, m)
   108  		} else {
   109  			setMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
   110  		}
   111  		return
   112  	}
   113  	return r.exchangeWithoutCache(ctx, m)
   114  }
   115  
   116  // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache
   117  func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   118  	q := m.Question[0]
   119  
   120  	ret, err, shared := r.group.Do(q.String(), func() (result interface{}, err error) {
   121  		defer func() {
   122  			if err != nil {
   123  				return
   124  			}
   125  
   126  			msg := result.(*D.Msg)
   127  
   128  			putMsgToCache(r.lruCache, q.String(), msg)
   129  		}()
   130  
   131  		isIPReq := isIPRequest(q)
   132  		if isIPReq {
   133  			return r.ipExchange(ctx, m)
   134  		}
   135  
   136  		if matched := r.matchPolicy(m); len(matched) != 0 {
   137  			return r.batchExchange(ctx, matched, m)
   138  		}
   139  		return r.batchExchange(ctx, r.main, m)
   140  	})
   141  
   142  	if err == nil {
   143  		msg = ret.(*D.Msg)
   144  		if shared {
   145  			msg = msg.Copy()
   146  		}
   147  	}
   148  
   149  	return
   150  }
   151  
   152  func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
   153  	fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout)
   154  	for _, client := range clients {
   155  		r := client
   156  		fast.Go(func() (interface{}, error) {
   157  			m, err := r.ExchangeContext(ctx, m)
   158  			if err != nil {
   159  				return nil, err
   160  			} else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused {
   161  				return nil, errors.New("server failure")
   162  			}
   163  			return m, nil
   164  		})
   165  	}
   166  
   167  	elm := fast.Wait()
   168  	if elm == nil {
   169  		err := errors.New("all DNS requests failed")
   170  		if fErr := fast.Error(); fErr != nil {
   171  			err = fmt.Errorf("%w, first error: %s", err, fErr.Error())
   172  		}
   173  		return nil, err
   174  	}
   175  
   176  	msg = elm.(*D.Msg)
   177  	return
   178  }
   179  
   180  func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
   181  	if r.policy == nil {
   182  		return nil
   183  	}
   184  
   185  	domain := r.msgToDomain(m)
   186  	if domain == "" {
   187  		return nil
   188  	}
   189  
   190  	record := r.policy.Search(domain)
   191  	if record == nil {
   192  		return nil
   193  	}
   194  
   195  	return record.Data.([]dnsClient)
   196  }
   197  
   198  func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
   199  	if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
   200  		return false
   201  	}
   202  
   203  	domain := r.msgToDomain(m)
   204  
   205  	if domain == "" {
   206  		return false
   207  	}
   208  
   209  	for _, df := range r.fallbackDomainFilters {
   210  		if df.Match(domain) {
   211  			return true
   212  		}
   213  	}
   214  
   215  	return false
   216  }
   217  
   218  func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
   219  	if matched := r.matchPolicy(m); len(matched) != 0 {
   220  		res := <-r.asyncExchange(ctx, matched, m)
   221  		return res.Msg, res.Error
   222  	}
   223  
   224  	onlyFallback := r.shouldOnlyQueryFallback(m)
   225  
   226  	if onlyFallback {
   227  		res := <-r.asyncExchange(ctx, r.fallback, m)
   228  		return res.Msg, res.Error
   229  	}
   230  
   231  	msgCh := r.asyncExchange(ctx, r.main, m)
   232  
   233  	if r.fallback == nil { // directly return if no fallback servers are available
   234  		res := <-msgCh
   235  		msg, err = res.Msg, res.Error
   236  		return
   237  	}
   238  
   239  	fallbackMsg := r.asyncExchange(ctx, r.fallback, m)
   240  	res := <-msgCh
   241  	if res.Error == nil {
   242  		if ips := msgToIP(res.Msg); len(ips) != 0 {
   243  			if !r.shouldIPFallback(ips[0]) {
   244  				msg = res.Msg // no need to wait for fallback result
   245  				err = res.Error
   246  				return msg, err
   247  			}
   248  		}
   249  	}
   250  
   251  	res = <-fallbackMsg
   252  	msg, err = res.Msg, res.Error
   253  	return
   254  }
   255  
   256  func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) {
   257  	ip = net.ParseIP(host)
   258  	if ip != nil {
   259  		isIPv4 := ip.To4() != nil
   260  		if dnsType == D.TypeAAAA && !isIPv4 {
   261  			return ip, nil
   262  		} else if dnsType == D.TypeA && isIPv4 {
   263  			return ip, nil
   264  		} else {
   265  			return nil, resolver.ErrIPVersion
   266  		}
   267  	}
   268  
   269  	query := &D.Msg{}
   270  	query.SetQuestion(D.Fqdn(host), dnsType)
   271  
   272  	msg, err := r.Exchange(query)
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  
   277  	ips := msgToIP(msg)
   278  	ipLength := len(ips)
   279  	if ipLength == 0 {
   280  		return nil, resolver.ErrIPNotFound
   281  	}
   282  
   283  	ip = ips[rand.Intn(ipLength)]
   284  	return
   285  }
   286  
   287  func (r *Resolver) msgToDomain(msg *D.Msg) string {
   288  	if len(msg.Question) > 0 {
   289  		return strings.TrimRight(msg.Question[0].Name, ".")
   290  	}
   291  
   292  	return ""
   293  }
   294  
   295  func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result {
   296  	ch := make(chan *result, 1)
   297  	go func() {
   298  		res, err := r.batchExchange(ctx, client, msg)
   299  		ch <- &result{Msg: res, Error: err}
   300  	}()
   301  	return ch
   302  }
   303  
   304  type NameServer struct {
   305  	Net       string
   306  	Addr      string
   307  	Interface string
   308  }
   309  
   310  type FallbackFilter struct {
   311  	GeoIP     bool
   312  	GeoIPCode string
   313  	IPCIDR    []*net.IPNet
   314  	Domain    []string
   315  }
   316  
   317  type Config struct {
   318  	Main, Fallback []NameServer
   319  	Default        []NameServer
   320  	IPv6           bool
   321  	EnhancedMode   C.DNSMode
   322  	FallbackFilter FallbackFilter
   323  	Pool           *fakeip.Pool
   324  	Hosts          *trie.DomainTrie
   325  	Policy         map[string]NameServer
   326  }
   327  
   328  func NewResolver(config Config) *Resolver {
   329  	defaultResolver := &Resolver{
   330  		main:     transform(config.Default, nil),
   331  		lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
   332  	}
   333  
   334  	r := &Resolver{
   335  		ipv6:     config.IPv6,
   336  		main:     transform(config.Main, defaultResolver),
   337  		lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
   338  		hosts:    config.Hosts,
   339  	}
   340  
   341  	if len(config.Fallback) != 0 {
   342  		r.fallback = transform(config.Fallback, defaultResolver)
   343  	}
   344  
   345  	if len(config.Policy) != 0 {
   346  		r.policy = trie.New()
   347  		for domain, nameserver := range config.Policy {
   348  			r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver))
   349  		}
   350  	}
   351  
   352  	fallbackIPFilters := []fallbackIPFilter{}
   353  	if config.FallbackFilter.GeoIP {
   354  		fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
   355  			code: config.FallbackFilter.GeoIPCode,
   356  		})
   357  	}
   358  	for _, ipnet := range config.FallbackFilter.IPCIDR {
   359  		fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
   360  	}
   361  	r.fallbackIPFilters = fallbackIPFilters
   362  
   363  	if len(config.FallbackFilter.Domain) != 0 {
   364  		fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)}
   365  		r.fallbackDomainFilters = fallbackDomainFilters
   366  	}
   367  
   368  	return r
   369  }