github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/doh.go (about)

     1  package dns
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    17  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    18  	pd "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns"
    19  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    20  	ynet "github.com/Asutorufa/yuhaiin/pkg/utils/net"
    21  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    22  	"github.com/Asutorufa/yuhaiin/pkg/utils/relay"
    23  	"github.com/Asutorufa/yuhaiin/pkg/utils/singleflight"
    24  )
    25  
// init registers NewDoH as the constructor for the pd.Type_doh resolver type.
func init() {
	Register(pd.Type_doh, NewDoH)
}
    29  
    30  func NewDoH(config Config) (netapi.Resolver, error) {
    31  	req, err := getRequest(config.Host)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	host := req.r.Host
    37  	_, port, err := net.SplitHostPort(req.r.Host)
    38  	if err != nil || port == "" {
    39  		host = net.JoinHostPort(host, "443")
    40  	}
    41  
    42  	addr, err := netapi.ParseAddress(statistic.Type_tcp, host)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	if config.Servername == "" {
    48  		config.Servername = req.Clone(context.TODO(), nil).URL.Hostname()
    49  	}
    50  
    51  	tlsConfig := &tls.Config{
    52  		ServerName: config.Servername,
    53  	}
    54  
    55  	type transportStore struct {
    56  		transport *transport
    57  		time      time.Time
    58  	}
    59  
    60  	roundTripper := atomic.Pointer[transportStore]{}
    61  
    62  	var sf singleflight.Group[struct{}, struct{}]
    63  
    64  	refreshRoundTripper := func() {
    65  		rt := roundTripper.Load()
    66  		if rt != nil {
    67  			if time.Since(rt.time) <= time.Second*5 {
    68  				return
    69  			}
    70  
    71  			rt.transport.Close()
    72  		}
    73  
    74  		_, _, _ = sf.Do(struct{}{}, func() (struct{}, error) {
    75  			roundTripper.Store(&transportStore{
    76  				transport: newTransport(&http.Transport{
    77  					TLSClientConfig:   tlsConfig,
    78  					ForceAttemptHTTP2: true,
    79  					DialContext: func(ctx context.Context, network, host string) (net.Conn, error) {
    80  						return config.Dialer.Conn(ctx, addr)
    81  					},
    82  					MaxIdleConns:          100,
    83  					IdleConnTimeout:       90 * time.Second,
    84  					TLSHandshakeTimeout:   10 * time.Second,
    85  					ExpectContinueTimeout: 1 * time.Second,
    86  				}),
    87  				time: time.Now(),
    88  			})
    89  
    90  			return struct{}{}, nil
    91  		})
    92  	}
    93  
    94  	refreshRoundTripper()
    95  
    96  	return NewClient(config,
    97  		func(ctx context.Context, b []byte) (*pool.Bytes, error) {
    98  			resp, err := roundTripper.Load().transport.RoundTrip(req.Clone(ctx, b))
    99  			if err != nil {
   100  				refreshRoundTripper() // https://github.com/golang/go/issues/30702
   101  				return nil, fmt.Errorf("doh post failed: %w", err)
   102  			}
   103  			defer resp.Body.Close()
   104  
   105  			if resp.StatusCode != http.StatusOK {
   106  				_, _ = relay.Copy(io.Discard, resp.Body) // By consuming the whole body the TLS connection may be reused on the next request.
   107  				return nil, fmt.Errorf("doh post return code: %d", resp.StatusCode)
   108  			}
   109  
   110  			buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
   111  
   112  			_, err = buf.ReadFull(resp.Body)
   113  			if err != nil {
   114  				buf.Free()
   115  				return nil, fmt.Errorf("doh post failed: %w", err)
   116  			}
   117  
   118  			return buf, nil
   119  
   120  			/*
   121  				* Get
   122  				urls := fmt.Sprintf(
   123  					"%s?dns=%s",
   124  					url,
   125  					strings.TrimSuffix(base64.URLEncoding.EncodeToString(dReq), "="),
   126  				)
   127  				resp, err := httpClient.Get(urls)
   128  			*/
   129  		}), nil
   130  }
   131  
   132  // https://tools.ietf.org/html/rfc8484
   133  func getUrlAndHost(host string) string {
   134  	scheme, rest, _ := ynet.GetScheme(host)
   135  	if scheme == "" {
   136  		host = "https://" + host
   137  	}
   138  
   139  	rest = strings.TrimPrefix(rest, "//")
   140  
   141  	if rest == "" {
   142  		host += "no-host-specified"
   143  	}
   144  
   145  	if !strings.Contains(rest, "/") {
   146  		host = host + "/dns-query"
   147  	}
   148  
   149  	return host
   150  }
   151  
// post holds a prepared HTTP request that serves as the template from which
// per-query requests are cloned (see Clone).
type post struct {
	// r is the template request; Clone copies it and attaches a body.
	r *http.Request
}
   155  
   156  func getRequest(host string) (*post, error) {
   157  	uri := getUrlAndHost(host)
   158  	req, err := http.NewRequest(http.MethodPost, uri, nil)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	req.Header.Set("Content-Type", "application/dns-message")
   163  	req.Header.Set("Accept", "application/dns-message")
   164  	return &post{req}, nil
   165  }
   166  
   167  func (p *post) Clone(ctx context.Context, body []byte) *http.Request {
   168  	req := p.r.Clone(ctx)
   169  	req.ContentLength = int64(len(body))
   170  	req.Body = io.NopCloser(bytes.NewBuffer(body))
   171  	req.GetBody = func() (io.ReadCloser, error) {
   172  		return io.NopCloser(bytes.NewReader(body)), nil
   173  	}
   174  
   175  	return req
   176  }
   177  
// transport wraps an *http.Transport and records every connection produced by
// its dialer so that Close can close them all.
type transport struct {
	*http.Transport

	// mu is held by DialContext while it appends to conns.
	mu sync.Mutex
	// conns collects every connection returned by dialContext.
	conns []net.Conn
	// dialContext is the dial function the wrapped http.Transport was
	// configured with before newTransport replaced it with DialContext.
	dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}
   185  
   186  func newTransport(p *http.Transport) *transport {
   187  	t := &transport{}
   188  
   189  	t.dialContext = p.DialContext
   190  	p.DialContext = t.DialContext
   191  
   192  	t.Transport = p
   193  
   194  	return t
   195  }
   196  
   197  func (t *transport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   198  	conn, err := t.dialContext(ctx, network, addr)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  
   203  	t.mu.Lock()
   204  	t.conns = append(t.conns, conn)
   205  	t.mu.Unlock()
   206  
   207  	return conn, nil
   208  }
   209  
   210  func (t *transport) Close() {
   211  	for _, v := range t.conns {
   212  		_ = v.Close()
   213  	}
   214  	t.Transport.CloseIdleConnections()
   215  }