github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/dns/resolver.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"math/rand"
     9  	"net"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/laof/lite-speed-test/common/cache"
    14  	"github.com/laof/lite-speed-test/common/picker"
    15  	"github.com/laof/lite-speed-test/transport/resolver"
    16  	"golang.org/x/sync/singleflight"
    17  
    18  	// "github.com/Dreamacro/clash/component/trie"
    19  
    20  	D "github.com/miekg/dns"
    21  )
    22  
    23  var (
    24  	globalSessionCache = tls.NewLRUClientSessionCache(64)
    25  )
    26  
    27  type dnsClient interface {
    28  	Exchange(m *D.Msg) (msg *D.Msg, err error)
    29  	ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
    30  }
    31  
    32  type result struct {
    33  	Msg   *D.Msg
    34  	Error error
    35  }
    36  
    37  type Resolver struct {
    38  	ipv6     bool
    39  	main     []dnsClient
    40  	fallback []dnsClient
    41  	group    singleflight.Group
    42  	lruCache *cache.LruCache
    43  }
    44  
    45  // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
    46  func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
    47  	ch := make(chan net.IP, 1)
    48  	go func() {
    49  		defer close(ch)
    50  		ip, err := r.resolveIP(host, D.TypeAAAA)
    51  		if err != nil {
    52  			return
    53  		}
    54  		ch <- ip
    55  	}()
    56  
    57  	ip, err = r.resolveIP(host, D.TypeA)
    58  	if err == nil && ip != nil {
    59  		return
    60  	}
    61  
    62  	ip, open := <-ch
    63  	if !open {
    64  		return nil, resolver.ErrIPNotFound
    65  	}
    66  
    67  	return ip, nil
    68  }
    69  
    70  // ResolveIPv4 request with TypeA
    71  func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) {
    72  	return r.resolveIP(host, D.TypeA)
    73  }
    74  
    75  // ResolveIPv6 request with TypeAAAA
    76  func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
    77  	return r.resolveIP(host, D.TypeAAAA)
    78  }
    79  
    80  func (r *Resolver) shouldIPFallback(ip net.IP) bool {
    81  	return false
    82  }
    83  
    84  // Exchange a batch of dns request, and it use cache
    85  func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
    86  	if len(m.Question) == 0 {
    87  		return nil, errors.New("should have one question at least")
    88  	}
    89  
    90  	q := m.Question[0]
    91  	cache, expireTime, hit := r.lruCache.GetWithExpire(q.String())
    92  	if hit {
    93  		now := time.Now()
    94  		msg = cache.(*D.Msg).Copy()
    95  		if expireTime.Before(now) {
    96  			setMsgTTL(msg, uint32(1)) // Continue fetch
    97  			go r.exchangeWithoutCache(m)
    98  		} else {
    99  			setMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
   100  		}
   101  		return
   102  	}
   103  	return r.exchangeWithoutCache(m)
   104  }
   105  
   106  // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache
   107  func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
   108  	q := m.Question[0]
   109  
   110  	ret, err, shared := r.group.Do(q.String(), func() (result interface{}, err error) {
   111  		defer func() {
   112  			if err != nil {
   113  				return
   114  			}
   115  
   116  			msg := result.(*D.Msg)
   117  
   118  			putMsgToCache(r.lruCache, q.String(), msg)
   119  		}()
   120  
   121  		isIPReq := isIPRequest(q)
   122  		if isIPReq {
   123  			return r.ipExchange(m)
   124  		}
   125  
   126  		return r.batchExchange(r.main, m)
   127  	})
   128  
   129  	if err == nil {
   130  		msg = ret.(*D.Msg)
   131  		if shared {
   132  			msg = msg.Copy()
   133  		}
   134  	}
   135  
   136  	return
   137  }
   138  
   139  func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
   140  	fast, ctx := picker.WithTimeout(context.Background(), time.Second*5)
   141  	for _, client := range clients {
   142  		r := client
   143  		fast.Go(func() (interface{}, error) {
   144  			m, err := r.ExchangeContext(ctx, m)
   145  			if err != nil {
   146  				return nil, err
   147  			} else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused {
   148  				return nil, errors.New("server failure")
   149  			}
   150  			return m, nil
   151  		})
   152  	}
   153  
   154  	elm := fast.Wait()
   155  	if elm == nil {
   156  		err := errors.New("all DNS requests failed")
   157  		if fErr := fast.Error(); fErr != nil {
   158  			err = fmt.Errorf("%w, first error: %s", err, fErr.Error())
   159  		}
   160  		return nil, err
   161  	}
   162  
   163  	msg = elm.(*D.Msg)
   164  	return
   165  }
   166  
   167  func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
   168  	if r.fallback == nil {
   169  		return false
   170  	}
   171  
   172  	domain := r.msgToDomain(m)
   173  
   174  	if domain == "" {
   175  		return false
   176  	}
   177  
   178  	return false
   179  }
   180  
   181  func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
   182  
   183  	onlyFallback := r.shouldOnlyQueryFallback(m)
   184  
   185  	if onlyFallback {
   186  		res := <-r.asyncExchange(r.fallback, m)
   187  		return res.Msg, res.Error
   188  	}
   189  
   190  	msgCh := r.asyncExchange(r.main, m)
   191  
   192  	if r.fallback == nil { // directly return if no fallback servers are available
   193  		res := <-msgCh
   194  		msg, err = res.Msg, res.Error
   195  		return
   196  	}
   197  
   198  	fallbackMsg := r.asyncExchange(r.fallback, m)
   199  	res := <-msgCh
   200  	if res.Error == nil {
   201  		if ips := r.msgToIP(res.Msg); len(ips) != 0 {
   202  			if !r.shouldIPFallback(ips[0]) {
   203  				msg = res.Msg // no need to wait for fallback result
   204  				err = res.Error
   205  				return msg, err
   206  			}
   207  		}
   208  	}
   209  
   210  	res = <-fallbackMsg
   211  	msg, err = res.Msg, res.Error
   212  	return
   213  }
   214  
   215  func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) {
   216  	ip = net.ParseIP(host)
   217  	if ip != nil {
   218  		isIPv4 := ip.To4() != nil
   219  		if dnsType == D.TypeAAAA && !isIPv4 {
   220  			return ip, nil
   221  		} else if dnsType == D.TypeA && isIPv4 {
   222  			return ip, nil
   223  		} else {
   224  			return nil, resolver.ErrIPVersion
   225  		}
   226  	}
   227  
   228  	query := &D.Msg{}
   229  	query.SetQuestion(D.Fqdn(host), dnsType)
   230  
   231  	msg, err := r.Exchange(query)
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	ips := r.msgToIP(msg)
   237  	ipLength := len(ips)
   238  	if ipLength == 0 {
   239  		return nil, resolver.ErrIPNotFound
   240  	}
   241  
   242  	ip = ips[rand.Intn(ipLength)]
   243  	return
   244  }
   245  
   246  func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
   247  	ips := []net.IP{}
   248  
   249  	for _, answer := range msg.Answer {
   250  		switch ans := answer.(type) {
   251  		case *D.AAAA:
   252  			ips = append(ips, ans.AAAA)
   253  		case *D.A:
   254  			ips = append(ips, ans.A)
   255  		}
   256  	}
   257  
   258  	return ips
   259  }
   260  
   261  func (r *Resolver) msgToDomain(msg *D.Msg) string {
   262  	if len(msg.Question) > 0 {
   263  		return strings.TrimRight(msg.Question[0].Name, ".")
   264  	}
   265  
   266  	return ""
   267  }
   268  
   269  func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
   270  	ch := make(chan *result, 1)
   271  	go func() {
   272  		res, err := r.batchExchange(client, msg)
   273  		ch <- &result{Msg: res, Error: err}
   274  	}()
   275  	return ch
   276  }
   277  
   278  type NameServer struct {
   279  	Net  string
   280  	Addr string
   281  }
   282  
   283  type FallbackFilter struct {
   284  	GeoIP  bool
   285  	IPCIDR []*net.IPNet
   286  	Domain []string
   287  }
   288  
   289  type Config struct {
   290  	Main, Fallback []NameServer
   291  	Default        []NameServer
   292  	IPv6           bool
   293  	FallbackFilter FallbackFilter
   294  	// Pool           *fakeip.Pool
   295  	// Hosts *trie.DomainTrie
   296  }
   297  
   298  func NewResolver(config Config) *Resolver {
   299  	defaultResolver := &Resolver{
   300  		main:     transform(config.Default, nil),
   301  		lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
   302  	}
   303  
   304  	r := &Resolver{
   305  		ipv6:     config.IPv6,
   306  		lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)),
   307  		main:     transform(config.Main, defaultResolver),
   308  	}
   309  
   310  	if len(config.Fallback) != 0 {
   311  		r.fallback = transform(config.Fallback, defaultResolver)
   312  	}
   313  
   314  	return r
   315  }
   316  
   317  func DefaultResolver() *Resolver {
   318  	servers := []NameServer{
   319  		{
   320  			Net:  "udp",
   321  			Addr: "223.5.5.5:53",
   322  		},
   323  		{
   324  			Net:  "udp",
   325  			Addr: "8.8.8.8:53",
   326  		},
   327  	}
   328  	c := Config{
   329  		Main:    servers,
   330  		Default: servers,
   331  	}
   332  	return NewResolver(c)
   333  }