github.com/yaling888/clash@v1.53.0/dns/doh.go (about)

     1  package dns
     2  
import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"math/rand/v2"
	"net"
	"net/http"
	"net/netip"
	urlPkg "net/url"
	"sync"
	"time"

	D "github.com/miekg/dns"

	"github.com/yaling888/clash/component/resolver"
)
    20  
    21  const (
    22  	// dotMimeType is the DoH mimetype that should be used.
    23  	dotMimeType = "application/dns-message"
    24  )
    25  
    26  type contextKey string
    27  
    28  var _ dnsClient = (*dohClient)(nil)
    29  
    30  type dohClient struct {
    31  	r         *Resolver
    32  	url       string
    33  	addr      string
    34  	proxy     string
    35  	urlLog    string
    36  	transport *http.Transport
    37  
    38  	mux            sync.Mutex // guards following fields
    39  	resolved       bool
    40  	proxyTransport map[string]*http.Transport
    41  }
    42  
    43  func (dc *dohClient) IsLan() bool {
    44  	return false
    45  }
    46  
    47  func (dc *dohClient) Exchange(m *D.Msg) (msg *rMsg, err error) {
    48  	return dc.ExchangeContext(context.Background(), m)
    49  }
    50  
    51  func (dc *dohClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *rMsg, err error) {
    52  	dc.mux.Lock()
    53  	if !dc.resolved {
    54  		host, port, _ := net.SplitHostPort(dc.addr)
    55  		ips, err1 := resolver.LookupIPByResolver(context.Background(), host, dc.r)
    56  		if err1 != nil {
    57  			dc.mux.Unlock()
    58  			return nil, err1
    59  		}
    60  
    61  		u, _ := urlPkg.Parse(dc.url)
    62  		addr := net.JoinHostPort(ips[rand.IntN(len(ips))].String(), port)
    63  
    64  		u.Host = addr
    65  		dc.url = u.String()
    66  		dc.addr = addr
    67  		dc.resolved = true
    68  	}
    69  	dc.mux.Unlock()
    70  
    71  	proxy := dc.proxy
    72  	if p, ok := resolver.GetProxy(ctx); ok {
    73  		proxy = p
    74  	}
    75  
    76  	msg = &rMsg{Source: dc.urlLog}
    77  	if proxy != "" {
    78  		msg.Source += "(" + proxy + ")"
    79  		ctx = context.WithValue(ctx, proxyKey, proxy)
    80  	}
    81  
    82  	// https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
    83  	// In order to maximize cache friendliness, SHOULD use a DNS ID of 0 in every DNS request.
    84  	newM := *m
    85  	newM.Id = 0
    86  	req, err := dc.newRequest(&newM)
    87  	if err != nil {
    88  		return msg, err
    89  	}
    90  
    91  	var msg1 *D.Msg
    92  	req = req.WithContext(ctx)
    93  	msg1, err = dc.doRequest(req, proxy)
    94  	if err == nil {
    95  		msg1.Id = m.Id
    96  		msg.Msg = msg1
    97  	}
    98  	return
    99  }
   100  
   101  // newRequest returns a new DoH request given a dns.Msg.
   102  func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) {
   103  	buf, err := m.Pack()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	req, err := http.NewRequest(http.MethodPost, dc.url, bytes.NewReader(buf))
   109  	if err != nil {
   110  		return req, err
   111  	}
   112  
   113  	req.Header.Set("content-type", dotMimeType)
   114  	req.Header.Set("accept", dotMimeType)
   115  	return req, nil
   116  }
   117  
   118  func (dc *dohClient) doRequest(req *http.Request, proxy string) (msg *D.Msg, err error) {
   119  	client1 := &http.Client{Transport: dc.getTransport(proxy)}
   120  	resp, err := client1.Do(req)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	defer func() {
   125  		_ = resp.Body.Close()
   126  	}()
   127  
   128  	buf, err := io.ReadAll(resp.Body)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	msg = &D.Msg{}
   133  	err = msg.Unpack(buf)
   134  	return msg, err
   135  }
   136  
   137  func (dc *dohClient) getTransport(proxy string) *http.Transport {
   138  	if proxy == "" {
   139  		return dc.transport
   140  	}
   141  
   142  	dc.mux.Lock()
   143  	defer dc.mux.Unlock()
   144  
   145  	if transport, ok := dc.proxyTransport[proxy]; ok {
   146  		return transport
   147  	}
   148  
   149  	transport := &http.Transport{
   150  		ForceAttemptHTTP2:   dc.transport.ForceAttemptHTTP2,
   151  		DialContext:         dc.transport.DialContext,
   152  		TLSClientConfig:     dc.transport.TLSClientConfig.Clone(),
   153  		MaxIdleConnsPerHost: 5,
   154  		IdleConnTimeout:     10 * time.Minute,
   155  	}
   156  
   157  	dc.proxyTransport[proxy] = transport
   158  	return transport
   159  }
   160  
   161  func newDoHClient(url string, proxy string, r *Resolver) *dohClient {
   162  	u, _ := urlPkg.Parse(url)
   163  	host := u.Hostname()
   164  	port := u.Port()
   165  	if port == "" {
   166  		port = "443"
   167  	}
   168  	addr := net.JoinHostPort(host, port)
   169  
   170  	var proxyTransport map[string]*http.Transport
   171  	if proxy != "" {
   172  		proxyTransport = make(map[string]*http.Transport)
   173  	}
   174  
   175  	resolved := false
   176  	if _, err := netip.ParseAddr(host); err == nil {
   177  		resolved = true
   178  	}
   179  
   180  	return &dohClient{
   181  		r:        r,
   182  		url:      url,
   183  		addr:     addr,
   184  		proxy:    proxy,
   185  		urlLog:   url,
   186  		resolved: resolved,
   187  		transport: &http.Transport{
   188  			ForceAttemptHTTP2: true,
   189  			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   190  				return getTCPConn(ctx, addr)
   191  			},
   192  			TLSClientConfig: &tls.Config{
   193  				ServerName: host,
   194  				NextProtos: []string{"dns"},
   195  			},
   196  			MaxIdleConnsPerHost: 5,
   197  		},
   198  		proxyTransport: proxyTransport,
   199  	}
   200  }