github.com/metacubex/mihomo@v1.18.5/dns/util.go (about)

     1  package dns
     2  
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	N "github.com/metacubex/mihomo/common/net"
	"github.com/metacubex/mihomo/common/nnip"
	"github.com/metacubex/mihomo/common/picker"
	"github.com/metacubex/mihomo/component/dialer"
	"github.com/metacubex/mihomo/component/resolver"
	C "github.com/metacubex/mihomo/constant"
	"github.com/metacubex/mihomo/log"
	"github.com/metacubex/mihomo/tunnel"

	D "github.com/miekg/dns"
	"github.com/samber/lo"
)
    26  
    27  const (
    28  	MaxMsgSize = 65535
    29  )
    30  
    31  const serverFailureCacheTTL uint32 = 5
    32  
    33  func minimalTTL(records []D.RR) uint32 {
    34  	rr := lo.MinBy(records, func(r1 D.RR, r2 D.RR) bool {
    35  		return r1.Header().Ttl < r2.Header().Ttl
    36  	})
    37  	if rr == nil {
    38  		return 0
    39  	}
    40  	return rr.Header().Ttl
    41  }
    42  
    43  func updateTTL(records []D.RR, ttl uint32) {
    44  	if len(records) == 0 {
    45  		return
    46  	}
    47  	delta := minimalTTL(records) - ttl
    48  	for i := range records {
    49  		records[i].Header().Ttl = lo.Clamp(records[i].Header().Ttl-delta, 1, records[i].Header().Ttl)
    50  	}
    51  }
    52  
    53  func putMsgToCache(c dnsCache, key string, q D.Question, msg *D.Msg) {
    54  	// skip dns cache for acme challenge
    55  	if q.Qtype == D.TypeTXT && strings.HasPrefix(q.Name, "_acme-challenge.") {
    56  		log.Debugln("[DNS] dns cache ignored because of acme challenge for: %s", q.Name)
    57  		return
    58  	}
    59  
    60  	var ttl uint32
    61  	if msg.Rcode == D.RcodeServerFailure {
    62  		// [...] a resolver MAY cache a server failure response.
    63  		// If it does so it MUST NOT cache it for longer than five (5) minutes [...]
    64  		ttl = serverFailureCacheTTL
    65  	} else {
    66  		ttl = minimalTTL(append(append(msg.Answer, msg.Ns...), msg.Extra...))
    67  	}
    68  	if ttl == 0 {
    69  		return
    70  	}
    71  	c.SetWithExpire(key, msg.Copy(), time.Now().Add(time.Duration(ttl)*time.Second))
    72  }
    73  
    74  func setMsgTTL(msg *D.Msg, ttl uint32) {
    75  	for _, answer := range msg.Answer {
    76  		answer.Header().Ttl = ttl
    77  	}
    78  
    79  	for _, ns := range msg.Ns {
    80  		ns.Header().Ttl = ttl
    81  	}
    82  
    83  	for _, extra := range msg.Extra {
    84  		extra.Header().Ttl = ttl
    85  	}
    86  }
    87  
    88  func updateMsgTTL(msg *D.Msg, ttl uint32) {
    89  	updateTTL(msg.Answer, ttl)
    90  	updateTTL(msg.Ns, ttl)
    91  	updateTTL(msg.Extra, ttl)
    92  }
    93  
    94  func isIPRequest(q D.Question) bool {
    95  	return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA || q.Qtype == D.TypeCNAME)
    96  }
    97  
    98  func transform(servers []NameServer, resolver *Resolver) []dnsClient {
    99  	ret := make([]dnsClient, 0, len(servers))
   100  	for _, s := range servers {
   101  		switch s.Net {
   102  		case "https":
   103  			ret = append(ret, newDoHClient(s.Addr, resolver, s.PreferH3, s.Params, s.ProxyAdapter, s.ProxyName))
   104  			continue
   105  		case "dhcp":
   106  			ret = append(ret, newDHCPClient(s.Addr))
   107  			continue
   108  		case "system":
   109  			ret = append(ret, newSystemClient())
   110  			continue
   111  		case "rcode":
   112  			ret = append(ret, newRCodeClient(s.Addr))
   113  			continue
   114  		case "quic":
   115  			if doq, err := newDoQ(resolver, s.Addr, s.ProxyAdapter, s.ProxyName); err == nil {
   116  				ret = append(ret, doq)
   117  			} else {
   118  				log.Fatalln("DoQ format error: %v", err)
   119  			}
   120  			continue
   121  		}
   122  
   123  		host, port, _ := net.SplitHostPort(s.Addr)
   124  		ret = append(ret, &client{
   125  			Client: &D.Client{
   126  				Net: s.Net,
   127  				TLSConfig: &tls.Config{
   128  					ServerName: host,
   129  				},
   130  				UDPSize: 4096,
   131  				Timeout: 5 * time.Second,
   132  			},
   133  			port:         port,
   134  			host:         host,
   135  			iface:        s.Interface,
   136  			r:            resolver,
   137  			proxyAdapter: s.ProxyAdapter,
   138  			proxyName:    s.ProxyName,
   139  		})
   140  	}
   141  	return ret
   142  }
   143  
   144  func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
   145  	msg := &D.Msg{}
   146  	msg.Answer = []D.RR{}
   147  
   148  	msg.SetRcode(r, D.RcodeSuccess)
   149  	msg.Authoritative = true
   150  	msg.RecursionAvailable = true
   151  
   152  	return msg
   153  }
   154  
   155  func msgToIP(msg *D.Msg) []netip.Addr {
   156  	ips := []netip.Addr{}
   157  
   158  	for _, answer := range msg.Answer {
   159  		switch ans := answer.(type) {
   160  		case *D.AAAA:
   161  			ips = append(ips, nnip.IpToAddr(ans.AAAA))
   162  		case *D.A:
   163  			ips = append(ips, nnip.IpToAddr(ans.A))
   164  		}
   165  	}
   166  
   167  	return ips
   168  }
   169  
   170  func msgToDomain(msg *D.Msg) string {
   171  	if len(msg.Question) > 0 {
   172  		return strings.TrimRight(msg.Question[0].Name, ".")
   173  	}
   174  
   175  	return ""
   176  }
   177  
   178  type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error)
   179  
   180  func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string, opts ...dialer.Option) dialHandler {
   181  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
   182  		if len(proxyName) == 0 && proxyAdapter == nil {
   183  			opts = append(opts, dialer.WithResolver(r))
   184  			return dialer.DialContext(ctx, network, addr, opts...)
   185  		} else {
   186  			host, port, err := net.SplitHostPort(addr)
   187  			if err != nil {
   188  				return nil, err
   189  			}
   190  			uintPort, err := strconv.ParseUint(port, 10, 16)
   191  			if err != nil {
   192  				return nil, err
   193  			}
   194  			if proxyAdapter == nil {
   195  				var ok bool
   196  				proxyAdapter, ok = tunnel.Proxies()[proxyName]
   197  				if !ok {
   198  					opts = append(opts, dialer.WithInterface(proxyName))
   199  				}
   200  			}
   201  
   202  			if strings.Contains(network, "tcp") {
   203  				// tcp can resolve host by remote
   204  				metadata := &C.Metadata{
   205  					NetWork: C.TCP,
   206  					Host:    host,
   207  					DstPort: uint16(uintPort),
   208  				}
   209  				if proxyAdapter != nil {
   210  					if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback
   211  						dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
   212  						if err != nil {
   213  							return nil, err
   214  						}
   215  						metadata.Host = ""
   216  						metadata.DstIP = dstIP
   217  					}
   218  					return proxyAdapter.DialContext(ctx, metadata, opts...)
   219  				}
   220  				opts = append(opts, dialer.WithResolver(r))
   221  				return dialer.DialContext(ctx, network, addr, opts...)
   222  			} else {
   223  				// udp must resolve host first
   224  				dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
   225  				if err != nil {
   226  					return nil, err
   227  				}
   228  				metadata := &C.Metadata{
   229  					NetWork: C.UDP,
   230  					Host:    "",
   231  					DstIP:   dstIP,
   232  					DstPort: uint16(uintPort),
   233  				}
   234  				if proxyAdapter == nil {
   235  					return dialer.DialContext(ctx, network, addr, opts...)
   236  				}
   237  
   238  				if !proxyAdapter.SupportUDP() {
   239  					return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter)
   240  				}
   241  
   242  				packetConn, err := proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
   243  				if err != nil {
   244  					return nil, err
   245  				}
   246  
   247  				return N.NewBindPacketConn(packetConn, metadata.UDPAddr()), nil
   248  			}
   249  		}
   250  	}
   251  }
   252  
   253  func listenPacket(ctx context.Context, proxyAdapter C.ProxyAdapter, proxyName string, network string, addr string, r *Resolver, opts ...dialer.Option) (net.PacketConn, error) {
   254  	host, port, err := net.SplitHostPort(addr)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  	uintPort, err := strconv.ParseUint(port, 10, 16)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	if proxyAdapter == nil {
   263  		var ok bool
   264  		proxyAdapter, ok = tunnel.Proxies()[proxyName]
   265  		if !ok {
   266  			opts = append(opts, dialer.WithInterface(proxyName))
   267  		}
   268  	}
   269  
   270  	// udp must resolve host first
   271  	dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	metadata := &C.Metadata{
   276  		NetWork: C.UDP,
   277  		Host:    "",
   278  		DstIP:   dstIP,
   279  		DstPort: uint16(uintPort),
   280  	}
   281  	if proxyAdapter == nil {
   282  		return dialer.NewDialer(opts...).ListenPacket(ctx, network, "", netip.AddrPortFrom(metadata.DstIP, metadata.DstPort))
   283  	}
   284  
   285  	if !proxyAdapter.SupportUDP() {
   286  		return nil, fmt.Errorf("proxy adapter [%s] UDP is not supported", proxyAdapter)
   287  	}
   288  
   289  	return proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
   290  }
   291  
   292  func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) {
   293  	cache = true
   294  	fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout)
   295  	defer fast.Close()
   296  	domain := msgToDomain(m)
   297  	var noIpMsg *D.Msg
   298  	for _, client := range clients {
   299  		if _, isRCodeClient := client.(rcodeClient); isRCodeClient {
   300  			msg, err = client.ExchangeContext(ctx, m)
   301  			return msg, false, err
   302  		}
   303  		client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop
   304  		fast.Go(func() (*D.Msg, error) {
   305  			log.Debugln("[DNS] resolve %s from %s", domain, client.Address())
   306  			m, err := client.ExchangeContext(ctx, m)
   307  			if err != nil {
   308  				return nil, err
   309  			} else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) {
   310  				// currently, cache indicates whether this msg was from a RCode client,
   311  				// so we would ignore RCode errors from RCode clients.
   312  				return nil, errors.New("server failure: " + D.RcodeToString[m.Rcode])
   313  			}
   314  			if ips := msgToIP(m); len(m.Question) > 0 {
   315  				qType := m.Question[0].Qtype
   316  				log.Debugln("[DNS] %s --> %s %s from %s", domain, ips, D.Type(qType), client.Address())
   317  				switch qType {
   318  				case D.TypeAAAA:
   319  					if len(ips) == 0 {
   320  						noIpMsg = m
   321  						return nil, resolver.ErrIPNotFound
   322  					}
   323  				case D.TypeA:
   324  					if len(ips) == 0 {
   325  						noIpMsg = m
   326  						return nil, resolver.ErrIPNotFound
   327  					}
   328  				}
   329  			}
   330  			return m, nil
   331  		})
   332  	}
   333  
   334  	msg = fast.Wait()
   335  	if msg == nil {
   336  		if noIpMsg != nil {
   337  			return noIpMsg, false, nil
   338  		}
   339  		err = errors.New("all DNS requests failed")
   340  		if fErr := fast.Error(); fErr != nil {
   341  			err = fmt.Errorf("%w, first error: %w", err, fErr)
   342  		}
   343  	}
   344  	return
   345  }