github.com/metacubex/mihomo@v1.18.5/dns/client.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"strings"
    10  
    11  	"github.com/metacubex/mihomo/component/ca"
    12  	"github.com/metacubex/mihomo/component/dialer"
    13  	"github.com/metacubex/mihomo/component/resolver"
    14  	C "github.com/metacubex/mihomo/constant"
    15  	"github.com/metacubex/mihomo/log"
    16  
    17  	D "github.com/miekg/dns"
    18  	"github.com/zhangyunhao116/fastrand"
    19  )
    20  
    21  type client struct {
    22  	*D.Client
    23  	r            *Resolver
    24  	port         string
    25  	host         string
    26  	iface        string
    27  	proxyAdapter C.ProxyAdapter
    28  	proxyName    string
    29  	addr         string
    30  }
    31  
    32  var _ dnsClient = (*client)(nil)
    33  
    34  // Address implements dnsClient
    35  func (c *client) Address() string {
    36  	if len(c.addr) != 0 {
    37  		return c.addr
    38  	}
    39  	schema := "udp"
    40  	if strings.HasPrefix(c.Client.Net, "tcp") {
    41  		schema = "tcp"
    42  		if strings.HasSuffix(c.Client.Net, "tls") {
    43  			schema = "tls"
    44  		}
    45  	}
    46  
    47  	c.addr = fmt.Sprintf("%s://%s", schema, net.JoinHostPort(c.host, c.port))
    48  	return c.addr
    49  }
    50  
    51  func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) {
    52  	var (
    53  		ip  netip.Addr
    54  		err error
    55  	)
    56  	if c.r == nil {
    57  		// a default ip dns
    58  		if ip, err = netip.ParseAddr(c.host); err != nil {
    59  			return nil, fmt.Errorf("dns %s not a valid ip", c.host)
    60  		}
    61  	} else {
    62  		ips, err := resolver.LookupIPWithResolver(ctx, c.host, c.r)
    63  		if err != nil {
    64  			return nil, fmt.Errorf("use default dns resolve failed: %w", err)
    65  		} else if len(ips) == 0 {
    66  			return nil, fmt.Errorf("%w: %s", resolver.ErrIPNotFound, c.host)
    67  		}
    68  		ip = ips[fastrand.Intn(len(ips))]
    69  	}
    70  
    71  	network := "udp"
    72  	if strings.HasPrefix(c.Client.Net, "tcp") {
    73  		network = "tcp"
    74  	}
    75  
    76  	var options []dialer.Option
    77  	if c.iface != "" {
    78  		options = append(options, dialer.WithInterface(c.iface))
    79  	}
    80  
    81  	dialHandler := getDialHandler(c.r, c.proxyAdapter, c.proxyName, options...)
    82  	addr := net.JoinHostPort(ip.String(), c.port)
    83  	conn, err := dialHandler(ctx, network, addr)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	defer func() {
    88  		_ = conn.Close()
    89  	}()
    90  
    91  	// miekg/dns ExchangeContext doesn't respond to context cancel.
    92  	// this is a workaround
    93  	type result struct {
    94  		msg *D.Msg
    95  		err error
    96  	}
    97  	ch := make(chan result, 1)
    98  	go func() {
    99  		if strings.HasSuffix(c.Client.Net, "tls") {
   100  			conn = tls.Client(conn, ca.GetGlobalTLSConfig(c.Client.TLSConfig))
   101  		}
   102  
   103  		dConn := &D.Conn{
   104  			Conn:         conn,
   105  			UDPSize:      c.Client.UDPSize,
   106  			TsigSecret:   c.Client.TsigSecret,
   107  			TsigProvider: c.Client.TsigProvider,
   108  		}
   109  
   110  		msg, _, err := c.Client.ExchangeWithConn(m, dConn)
   111  
   112  		// Resolvers MUST resend queries over TCP if they receive a truncated UDP response (with TC=1 set)!
   113  		if msg != nil && msg.Truncated && c.Client.Net == "" {
   114  			tcpClient := *c.Client // copy a client
   115  			tcpClient.Net = "tcp"
   116  			network = "tcp"
   117  			log.Debugln("[DNS] Truncated reply from %s:%s for %s over UDP, retrying over TCP", c.host, c.port, m.Question[0].String())
   118  			dConn.Conn, err = dialHandler(ctx, network, addr)
   119  			if err != nil {
   120  				ch <- result{msg, err}
   121  				return
   122  			}
   123  			defer func() {
   124  				_ = conn.Close()
   125  			}()
   126  			msg, _, err = tcpClient.ExchangeWithConn(m, dConn)
   127  		}
   128  
   129  		ch <- result{msg, err}
   130  	}()
   131  
   132  	select {
   133  	case <-ctx.Done():
   134  		return nil, ctx.Err()
   135  	case ret := <-ch:
   136  		return ret.msg, ret.err
   137  	}
   138  }