github.com/yaling888/clash@v1.53.0/dns/util.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"slices"
    10  	"strconv"
    11  	"time"
    12  
    13  	D "github.com/miekg/dns"
    14  	"github.com/phuslu/log"
    15  	"github.com/samber/lo"
    16  
    17  	"github.com/yaling888/clash/common/cache"
    18  	"github.com/yaling888/clash/common/errors2"
    19  	"github.com/yaling888/clash/common/picker"
    20  	"github.com/yaling888/clash/component/dialer"
    21  	"github.com/yaling888/clash/component/resolver"
    22  	C "github.com/yaling888/clash/constant"
    23  	"github.com/yaling888/clash/tunnel"
    24  )
    25  
    26  const (
    27  	proxyKey     = contextKey("key-dns-client-proxy")
    28  	proxyTimeout = 10 * time.Second
    29  )
    30  
    31  func putMsgToCache(c *cache.LruCache[string, *rMsg], key string, msg *rMsg) {
    32  	putMsgToCacheWithExpire(c, key, msg, 0)
    33  }
    34  
    35  func putMsgToCacheWithExpire(c *cache.LruCache[string, *rMsg], key string, msg *rMsg, sec uint32) {
    36  	if sec == 0 {
    37  		if sec = minTTL(msg.Msg.Answer); sec == 0 {
    38  			if sec = minTTL(msg.Msg.Ns); sec == 0 {
    39  				sec = minTTL(msg.Msg.Extra)
    40  			}
    41  		}
    42  		if sec == 0 {
    43  			return
    44  		}
    45  		if !msg.Lan {
    46  			sec = max(sec, 300) // at least 5 minutes to cache
    47  		}
    48  	}
    49  
    50  	sortAnswer(msg.Msg.Answer)
    51  
    52  	c.SetWithExpire(key, msg.Copy(), time.Now().Add(time.Duration(sec)*time.Second))
    53  }
    54  
    55  func setMsgTTL(msg *D.Msg, ttl uint32) {
    56  	setMsgTTLWithForce(msg, ttl, true)
    57  }
    58  
    59  func setMsgMaxTTL(msg *D.Msg, ttl uint32) {
    60  	setMsgTTLWithForce(msg, ttl, false)
    61  }
    62  
    63  func setMsgTTLWithForce(msg *D.Msg, ttl uint32, force bool) {
    64  	setTTL(msg.Answer, ttl, force)
    65  	setTTL(msg.Ns, ttl, force)
    66  	setTTL(msg.Extra, ttl, force)
    67  }
    68  
    69  func setTTL(records []D.RR, ttl uint32, force bool) {
    70  	if force {
    71  		for i := range records {
    72  			if records[i].Header().Rrtype != D.TypeA &&
    73  				records[i].Header().Rrtype != D.TypeAAAA &&
    74  				records[i].Header().Ttl == 0 {
    75  				continue
    76  			}
    77  			records[i].Header().Ttl = ttl
    78  		}
    79  		return
    80  	}
    81  
    82  	delta := minTTL(records) - ttl
    83  	for i := range records {
    84  		if records[i].Header().Rrtype != D.TypeA &&
    85  			records[i].Header().Rrtype != D.TypeAAAA &&
    86  			records[i].Header().Ttl == 0 {
    87  			continue
    88  		}
    89  		records[i].Header().Ttl = min(max(records[i].Header().Ttl-delta, 1), records[i].Header().Ttl)
    90  	}
    91  }
    92  
    93  func minTTL(records []D.RR) uint32 {
    94  	minObj := lo.MinBy(records, func(r1 D.RR, r2 D.RR) bool {
    95  		return r1.Header().Ttl < r2.Header().Ttl
    96  	})
    97  	if minObj != nil {
    98  		return minObj.Header().Ttl
    99  	}
   100  	return 0
   101  }
   102  
   103  func sortAnswer(answer []D.RR) {
   104  	slices.SortFunc(answer, func(ip1, ip2 D.RR) int {
   105  		var (
   106  			addr1, addr2 netip.Addr
   107  			ok           bool
   108  		)
   109  		switch a := ip1.(type) {
   110  		case *D.A:
   111  			addr1, ok = netip.AddrFromSlice(a.A.To4())
   112  		case *D.AAAA:
   113  			addr1, ok = netip.AddrFromSlice(a.AAAA)
   114  		}
   115  		if !ok {
   116  			addr1 = netip.MustParseAddr("ffff::")
   117  		}
   118  		ok = false
   119  		switch a := ip2.(type) {
   120  		case *D.A:
   121  			addr2, ok = netip.AddrFromSlice(a.A.To4())
   122  		case *D.AAAA:
   123  			addr2, ok = netip.AddrFromSlice(a.AAAA)
   124  		}
   125  		if !ok {
   126  			addr2 = netip.MustParseAddr("ffff::")
   127  		}
   128  		return addr1.Compare(addr2)
   129  	})
   130  }
   131  
   132  func isIPRequest(q D.Question) bool {
   133  	return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA)
   134  }
   135  
   136  func transform(servers []NameServer, r *Resolver) []dnsClient {
   137  	var ret []dnsClient
   138  	for _, s := range servers {
   139  		switch s.Net {
   140  		case "https":
   141  			ret = append(ret, newDoHClient(s.Addr, s.Proxy, r))
   142  			continue
   143  		case "dhcp":
   144  			ret = append(ret, newDHCPClient(s.Addr))
   145  			continue
   146  		}
   147  
   148  		ret = append(ret, newClient(s.Net, s.Addr, s.Proxy, s.Interface, s.IsDHCP, r))
   149  	}
   150  	return ret
   151  }
   152  
   153  func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
   154  	msg := &D.Msg{}
   155  	msg.Answer = []D.RR{}
   156  
   157  	msg.SetRcode(r, D.RcodeSuccess)
   158  	msg.Authoritative = true
   159  	msg.RecursionAvailable = true
   160  
   161  	return msg
   162  }
   163  
   164  func msgToIP(msg *D.Msg) []netip.Addr {
   165  	var ips []netip.Addr
   166  
   167  	for _, answer := range msg.Answer {
   168  		switch ans := answer.(type) {
   169  		case *D.AAAA:
   170  			ip, ok := netip.AddrFromSlice(ans.AAAA)
   171  			if !ok {
   172  				continue
   173  			}
   174  			ips = append(ips, ip)
   175  		case *D.A:
   176  			ip, ok := netip.AddrFromSlice(ans.A.To4())
   177  			if !ok {
   178  				continue
   179  			}
   180  			ips = append(ips, ip)
   181  		}
   182  	}
   183  
   184  	return ips
   185  }
   186  
   187  func msgToIPStr(msg D.Msg) []string {
   188  	var ips []string
   189  
   190  	for _, answer := range msg.Answer {
   191  		switch ans := answer.(type) {
   192  		case *D.AAAA:
   193  			ips = append(ips, ans.AAAA.String())
   194  		case *D.A:
   195  			ips = append(ips, ans.A.String())
   196  		}
   197  	}
   198  
   199  	return ips
   200  }
   201  
   202  type wrapPacketConn struct {
   203  	net.PacketConn
   204  	rAddr net.Addr
   205  }
   206  
   207  func (wpc *wrapPacketConn) Read(b []byte) (n int, err error) {
   208  	n, _, err = wpc.PacketConn.ReadFrom(b)
   209  	return n, err
   210  }
   211  
   212  func (wpc *wrapPacketConn) Write(b []byte) (n int, err error) {
   213  	return wpc.PacketConn.WriteTo(b, wpc.rAddr)
   214  }
   215  
   216  func (wpc *wrapPacketConn) RemoteAddr() net.Addr {
   217  	return wpc.rAddr
   218  }
   219  
   220  func dialContextByProxyOrInterface(
   221  	ctx context.Context,
   222  	network string,
   223  	dstIP netip.Addr,
   224  	port string,
   225  	proxyOrInterface string,
   226  	opts ...dialer.Option,
   227  ) (net.Conn, error) {
   228  	proxy, ok := tunnel.FindProxyByName(proxyOrInterface)
   229  	if !ok {
   230  		opts = []dialer.Option{dialer.WithInterface(proxyOrInterface), dialer.WithRoutingMark(0)}
   231  		conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(dstIP.String(), port), opts...)
   232  		if err == nil {
   233  			return conn, nil
   234  		}
   235  		return nil, fmt.Errorf("proxy %s not found, %w", proxyOrInterface, err)
   236  	}
   237  
   238  	networkType := C.TCP
   239  	if network == "udp" {
   240  		networkType = C.UDP
   241  	}
   242  
   243  	p, _ := strconv.ParseUint(port, 10, 16)
   244  	metadata := &C.Metadata{
   245  		NetWork: networkType,
   246  		Host:    "",
   247  		DstIP:   dstIP,
   248  		DstPort: C.Port(p),
   249  	}
   250  
   251  	if networkType == C.UDP {
   252  		if !proxy.SupportUDP() {
   253  			if tunnel.UDPFallbackMatch.Load() {
   254  				return nil, fmt.Errorf("proxy %s UDP is not supported", proxy.Name())
   255  			} else {
   256  				log.Debug().
   257  					Str("proxy", proxy.Name()).
   258  					Msg("[DNS] proxy UDP is not supported, fallback to TCP")
   259  
   260  				metadata.NetWork = C.TCP
   261  				goto tcp
   262  			}
   263  		}
   264  
   265  		packetConn, err := proxy.ListenPacketContext(ctx, metadata, opts...)
   266  		if err != nil {
   267  			return nil, err
   268  		}
   269  
   270  		return &wrapPacketConn{
   271  			PacketConn: packetConn,
   272  			rAddr:      metadata.UDPAddr(),
   273  		}, nil
   274  	}
   275  
   276  tcp:
   277  	return proxy.DialContext(ctx, metadata, opts...)
   278  }
   279  
   280  func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *rMsg, err error) {
   281  	var (
   282  		fast *picker.Picker[*rMsg]
   283  		cs   = clients
   284  	)
   285  
   286  	if _, ok := ctx.Deadline(); ok {
   287  		fast, ctx = picker.WithContext[*rMsg](ctx)
   288  	} else {
   289  		fast, ctx = picker.WithTimeout[*rMsg](ctx, resolver.DefaultDNSTimeout)
   290  	}
   291  
   292  	for i := range cs {
   293  		r := cs[i]
   294  		fast.Go(func() (*rMsg, error) {
   295  			mm, fErr := r.ExchangeContext(ctx, m)
   296  			go logDnsResponse(m.Question[0], mm, fErr)
   297  			if fErr != nil {
   298  				return nil, fErr
   299  			} else if mm.Msg.Rcode == D.RcodeServerFailure || mm.Msg.Rcode == D.RcodeRefused {
   300  				return nil, errors.New("server failure")
   301  			}
   302  			return mm, nil
   303  		})
   304  	}
   305  
   306  	elm := fast.Wait()
   307  	if elm == nil {
   308  		err = errors.New("all DNS requests failed")
   309  		if fErr := fast.Error(); fErr != nil {
   310  			err = errors.Join(err, fErr)
   311  		}
   312  		return nil, errors2.Cause(err)
   313  	}
   314  
   315  	return elm, nil
   316  }
   317  
   318  func genMsgCacheKey(ctx context.Context, q D.Question) string {
   319  	if proxy, ok := resolver.GetProxy(ctx); ok && proxy != "" {
   320  		return fmt.Sprintf("%s:%s:%d:%d", proxy, q.Name, q.Qtype, q.Qclass)
   321  	}
   322  	return fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
   323  }
   324  
   325  func getTCPConn(ctx context.Context, addr string) (conn net.Conn, err error) {
   326  	if proxy, ok := ctx.Value(proxyKey).(string); ok {
   327  		host, port, _ := net.SplitHostPort(addr)
   328  		ip, err1 := netip.ParseAddr(host)
   329  		if err1 != nil {
   330  			return nil, err1
   331  		}
   332  		conn, err = dialContextByProxyOrInterface(ctx, "tcp", ip, port, proxy)
   333  	} else {
   334  		conn, err = dialer.DialContext(ctx, "tcp", addr)
   335  	}
   336  
   337  	if err == nil {
   338  		if c, ok := conn.(*net.TCPConn); ok {
   339  			_ = c.SetKeepAlive(true)
   340  		}
   341  	}
   342  	return
   343  }
   344  
   345  func logDnsResponse(q D.Question, msg *rMsg, err error) {
   346  	if msg == nil {
   347  		return
   348  	}
   349  	if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
   350  		return
   351  	}
   352  
   353  	if err != nil && !errors.Is(err, context.Canceled) {
   354  		log.Debug().
   355  			Err(err).
   356  			Str("source", msg.Source).
   357  			Str("qType", D.Type(q.Qtype).String()).
   358  			Str("name", q.Name).
   359  			Msg("[DNS] dns response failed")
   360  	} else if msg.Msg != nil {
   361  		log.Debug().
   362  			Str("source", msg.Source).
   363  			Str("qType", D.Type(q.Qtype).String()).
   364  			Str("name", q.Name).
   365  			EmbedObject(LogAnswer{ans: *msg.Msg}).
   366  			Uint32("ttl", minTTL(msg.Msg.Answer)).
   367  			Msg("[DNS] dns response")
   368  	}
   369  }
   370  
   371  type LogAnswer struct {
   372  	ans D.Msg
   373  }
   374  
   375  func (l LogAnswer) MarshalObject(e *log.Entry) {
   376  	e.Strs("answer", msgToIPStr(l.ans))
   377  }