github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/doh.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strings" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 17 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 18 pd "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns" 19 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 20 ynet "github.com/Asutorufa/yuhaiin/pkg/utils/net" 21 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 22 "github.com/Asutorufa/yuhaiin/pkg/utils/relay" 23 "github.com/Asutorufa/yuhaiin/pkg/utils/singleflight" 24 ) 25 26 func init() { 27 Register(pd.Type_doh, NewDoH) 28 } 29 30 func NewDoH(config Config) (netapi.Resolver, error) { 31 req, err := getRequest(config.Host) 32 if err != nil { 33 return nil, err 34 } 35 36 host := req.r.Host 37 _, port, err := net.SplitHostPort(req.r.Host) 38 if err != nil || port == "" { 39 host = net.JoinHostPort(host, "443") 40 } 41 42 addr, err := netapi.ParseAddress(statistic.Type_tcp, host) 43 if err != nil { 44 return nil, err 45 } 46 47 if config.Servername == "" { 48 config.Servername = req.Clone(context.TODO(), nil).URL.Hostname() 49 } 50 51 tlsConfig := &tls.Config{ 52 ServerName: config.Servername, 53 } 54 55 type transportStore struct { 56 transport *transport 57 time time.Time 58 } 59 60 roundTripper := atomic.Pointer[transportStore]{} 61 62 var sf singleflight.Group[struct{}, struct{}] 63 64 refreshRoundTripper := func() { 65 rt := roundTripper.Load() 66 if rt != nil { 67 if time.Since(rt.time) <= time.Second*5 { 68 return 69 } 70 71 rt.transport.Close() 72 } 73 74 _, _, _ = sf.Do(struct{}{}, func() (struct{}, error) { 75 roundTripper.Store(&transportStore{ 76 transport: newTransport(&http.Transport{ 77 TLSClientConfig: tlsConfig, 78 ForceAttemptHTTP2: true, 79 DialContext: func(ctx context.Context, network, host string) (net.Conn, error) { 80 return config.Dialer.Conn(ctx, addr) 81 }, 82 MaxIdleConns: 100, 83 IdleConnTimeout: 90 * time.Second, 84 TLSHandshakeTimeout: 10 * time.Second, 85 ExpectContinueTimeout: 1 * time.Second, 86 }), 87 time: time.Now(), 88 }) 89 90 return struct{}{}, nil 91 }) 92 } 93 94 refreshRoundTripper() 95 96 return NewClient(config, 97 func(ctx context.Context, b []byte) (*pool.Bytes, error) { 98 resp, err := roundTripper.Load().transport.RoundTrip(req.Clone(ctx, b)) 99 if err != nil { 100 refreshRoundTripper() // https://github.com/golang/go/issues/30702 101 return nil, fmt.Errorf("doh post failed: %w", err) 102 } 103 defer resp.Body.Close() 104 105 if resp.StatusCode != http.StatusOK { 106 _, _ = relay.Copy(io.Discard, resp.Body) // By consuming the whole body the TLS connection may be reused on the next request. 107 return nil, fmt.Errorf("doh post return code: %d", resp.StatusCode) 108 } 109 110 buf := pool.GetBytesBuffer(nat.MaxSegmentSize) 111 112 _, err = buf.ReadFull(resp.Body) 113 if err != nil { 114 buf.Free() 115 return nil, fmt.Errorf("doh post failed: %w", err) 116 } 117 118 return buf, nil 119 120 /* 121 * Get 122 urls := fmt.Sprintf( 123 "%s?dns=%s", 124 url, 125 strings.TrimSuffix(base64.URLEncoding.EncodeToString(dReq), "="), 126 ) 127 resp, err := httpClient.Get(urls) 128 */ 129 }), nil 130 } 131 132 // https://tools.ietf.org/html/rfc8484 133 func getUrlAndHost(host string) string { 134 scheme, rest, _ := ynet.GetScheme(host) 135 if scheme == "" { 136 host = "https://" + host 137 } 138 139 rest = strings.TrimPrefix(rest, "//") 140 141 if rest == "" { 142 host += "no-host-specified" 143 } 144 145 if !strings.Contains(rest, "/") { 146 host = host + "/dns-query" 147 } 148 149 return host 150 } 151 152 type post struct { 153 r *http.Request 154 } 155 156 func getRequest(host string) (*post, error) { 157 uri := getUrlAndHost(host) 158 req, err := http.NewRequest(http.MethodPost, uri, nil) 159 if err != nil { 160 return nil, err 161 } 162 req.Header.Set("Content-Type", "application/dns-message") 163 req.Header.Set("Accept", "application/dns-message") 164 return &post{req}, nil 165 } 166 167 func (p *post) Clone(ctx context.Context, body []byte) *http.Request { 168 req := p.r.Clone(ctx) 169 req.ContentLength = int64(len(body)) 170 req.Body = io.NopCloser(bytes.NewBuffer(body)) 171 req.GetBody = func() (io.ReadCloser, error) { 172 return io.NopCloser(bytes.NewReader(body)), nil 173 } 174 175 return req 176 } 177 178 type transport struct { 179 *http.Transport 180 181 mu sync.Mutex 182 conns []net.Conn 183 dialContext func(ctx context.Context, network, addr string) (net.Conn, error) 184 } 185 186 func newTransport(p *http.Transport) *transport { 187 t := &transport{} 188 189 t.dialContext = p.DialContext 190 p.DialContext = t.DialContext 191 192 t.Transport = p 193 194 return t 195 } 196 197 func (t *transport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 198 conn, err := t.dialContext(ctx, network, addr) 199 if err != nil { 200 return nil, err 201 } 202 203 t.mu.Lock() 204 t.conns = append(t.conns, conn) 205 t.mu.Unlock() 206 207 return conn, nil 208 } 209 210 func (t *transport) Close() { 211 for _, v := range t.conns { 212 _ = v.Close() 213 } 214 t.Transport.CloseIdleConnections() 215 }