github.com/v2fly/v2ray-core/v4@v4.45.2/app/dns/nameserver_doh.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "sync" 14 "sync/atomic" 15 "time" 16 17 "golang.org/x/net/dns/dnsmessage" 18 19 "github.com/v2fly/v2ray-core/v4/common" 20 "github.com/v2fly/v2ray-core/v4/common/net" 21 "github.com/v2fly/v2ray-core/v4/common/protocol/dns" 22 "github.com/v2fly/v2ray-core/v4/common/session" 23 "github.com/v2fly/v2ray-core/v4/common/signal/pubsub" 24 "github.com/v2fly/v2ray-core/v4/common/task" 25 dns_feature "github.com/v2fly/v2ray-core/v4/features/dns" 26 "github.com/v2fly/v2ray-core/v4/features/routing" 27 "github.com/v2fly/v2ray-core/v4/transport/internet" 28 ) 29 30 // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format, 31 // which is compatible with traditional dns over udp(RFC1035), 32 // thus most of the DOH implementation is copied from udpns.go 33 type DoHNameServer struct { 34 sync.RWMutex 35 ips map[string]*record 36 pub *pubsub.Service 37 cleanup *task.Periodic 38 reqID uint32 39 httpClient *http.Client 40 dohURL string 41 name string 42 } 43 44 // NewDoHNameServer creates DOH server object for remote resolving. 45 func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher) (*DoHNameServer, error) { 46 newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() 47 s := baseDOHNameServer(url, "DOH") 48 49 // Dispatched connection will be closed (interrupted) after each request 50 // This makes DOH inefficient without a keep-alived connection 51 // See: core/app/proxyman/outbound/handler.go:113 52 // Using mux (https request wrapped in a stream layer) improves the situation. 53 // Recommend to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on 54 // a normal network eg. the server side of v2ray 55 tr := &http.Transport{ 56 MaxIdleConns: 30, 57 IdleConnTimeout: 90 * time.Second, 58 TLSHandshakeTimeout: 30 * time.Second, 59 ForceAttemptHTTP2: true, 60 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 61 dest, err := net.ParseDestination(network + ":" + addr) 62 if err != nil { 63 return nil, err 64 } 65 66 link, err := dispatcher.Dispatch(ctx, dest) 67 if err != nil { 68 return nil, err 69 } 70 return net.NewConnection( 71 net.ConnectionInputMulti(link.Writer), 72 net.ConnectionOutputMulti(link.Reader), 73 ), nil 74 }, 75 } 76 77 dispatchedClient := &http.Client{ 78 Transport: tr, 79 Timeout: 60 * time.Second, 80 } 81 82 s.httpClient = dispatchedClient 83 return s, nil 84 } 85 86 // NewDoHLocalNameServer creates DOH client object for local resolving 87 func NewDoHLocalNameServer(url *url.URL) *DoHNameServer { 88 url.Scheme = "https" 89 s := baseDOHNameServer(url, "DOHL") 90 tr := &http.Transport{ 91 IdleConnTimeout: 90 * time.Second, 92 ForceAttemptHTTP2: true, 93 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 94 dest, err := net.ParseDestination(network + ":" + addr) 95 if err != nil { 96 return nil, err 97 } 98 conn, err := internet.DialSystem(ctx, dest, nil) 99 if err != nil { 100 return nil, err 101 } 102 return conn, nil 103 }, 104 } 105 s.httpClient = &http.Client{ 106 Timeout: time.Second * 180, 107 Transport: tr, 108 } 109 newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() 110 return s 111 } 112 113 func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { 114 s := &DoHNameServer{ 115 ips: make(map[string]*record), 116 pub: pubsub.NewService(), 117 name: prefix + "//" + url.Host, 118 dohURL: url.String(), 119 } 120 s.cleanup = &task.Periodic{ 121 Interval: time.Minute, 122 Execute: s.Cleanup, 123 } 124 return s 125 } 126 127 // Name implements Server. 128 func (s *DoHNameServer) Name() string { 129 return s.name 130 } 131 132 // Cleanup clears expired items from cache 133 func (s *DoHNameServer) Cleanup() error { 134 now := time.Now() 135 s.Lock() 136 defer s.Unlock() 137 138 if len(s.ips) == 0 { 139 return newError("nothing to do. stopping...") 140 } 141 142 for domain, record := range s.ips { 143 if record.A != nil && record.A.Expire.Before(now) { 144 record.A = nil 145 } 146 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 147 record.AAAA = nil 148 } 149 150 if record.A == nil && record.AAAA == nil { 151 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 152 delete(s.ips, domain) 153 } else { 154 s.ips[domain] = record 155 } 156 } 157 158 if len(s.ips) == 0 { 159 s.ips = make(map[string]*record) 160 } 161 162 return nil 163 } 164 165 func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 166 elapsed := time.Since(req.start) 167 168 s.Lock() 169 rec, found := s.ips[req.domain] 170 if !found { 171 rec = &record{} 172 } 173 updated := false 174 175 switch req.reqType { 176 case dnsmessage.TypeA: 177 if isNewer(rec.A, ipRec) { 178 rec.A = ipRec 179 updated = true 180 } 181 case dnsmessage.TypeAAAA: 182 addr := make([]net.Address, 0, len(ipRec.IP)) 183 for _, ip := range ipRec.IP { 184 if len(ip.IP()) == net.IPv6len { 185 addr = append(addr, ip) 186 } 187 } 188 ipRec.IP = addr 189 if isNewer(rec.AAAA, ipRec) { 190 rec.AAAA = ipRec 191 updated = true 192 } 193 } 194 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 195 196 if updated { 197 s.ips[req.domain] = rec 198 } 199 switch req.reqType { 200 case dnsmessage.TypeA: 201 s.pub.Publish(req.domain+"4", nil) 202 case dnsmessage.TypeAAAA: 203 s.pub.Publish(req.domain+"6", nil) 204 } 205 s.Unlock() 206 common.Must(s.cleanup.Start()) 207 } 208 209 func (s *DoHNameServer) newReqID() uint16 { 210 return uint16(atomic.AddUint32(&s.reqID, 1)) 211 } 212 213 func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 214 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 215 216 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 217 218 var deadline time.Time 219 if d, ok := ctx.Deadline(); ok { 220 deadline = d 221 } else { 222 deadline = time.Now().Add(time.Second * 5) 223 } 224 225 for _, req := range reqs { 226 go func(r *dnsRequest) { 227 // generate new context for each req, using same context 228 // may cause reqs all aborted if any one encounter an error 229 dnsCtx := ctx 230 231 // reserve internal dns server requested Inbound 232 if inbound := session.InboundFromContext(ctx); inbound != nil { 233 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 234 } 235 236 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 237 Protocol: "https", 238 SkipDNSResolve: true, 239 }) 240 241 // forced to use mux for DOH 242 dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) 243 244 var cancel context.CancelFunc 245 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 246 defer cancel() 247 248 b, err := dns.PackMessage(r.msg) 249 if err != nil { 250 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 251 return 252 } 253 resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) 254 if err != nil { 255 newError("failed to retrieve response").Base(err).AtError().WriteToLog() 256 return 257 } 258 rec, err := parseResponse(resp) 259 if err != nil { 260 newError("failed to handle DOH response").Base(err).AtError().WriteToLog() 261 return 262 } 263 s.updateIP(r, rec) 264 }(req) 265 } 266 } 267 268 func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { 269 body := bytes.NewBuffer(b) 270 req, err := http.NewRequest("POST", s.dohURL, body) 271 if err != nil { 272 return nil, err 273 } 274 275 req.Header.Add("Accept", "application/dns-message") 276 req.Header.Add("Content-Type", "application/dns-message") 277 278 resp, err := s.httpClient.Do(req.WithContext(ctx)) 279 if err != nil { 280 return nil, err 281 } 282 283 defer resp.Body.Close() 284 if resp.StatusCode != http.StatusOK { 285 io.Copy(io.Discard, resp.Body) // flush resp.Body so that the conn is reusable 286 return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode) 287 } 288 289 return io.ReadAll(resp.Body) 290 } 291 292 func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 293 s.RLock() 294 record, found := s.ips[domain] 295 s.RUnlock() 296 297 if !found { 298 return nil, errRecordNotFound 299 } 300 301 var err4 error 302 var err6 error 303 var ips []net.Address 304 var ip6 []net.Address 305 306 if option.IPv4Enable { 307 ips, err4 = record.A.getIPs() 308 } 309 310 if option.IPv6Enable { 311 ip6, err6 = record.AAAA.getIPs() 312 ips = append(ips, ip6...) 313 } 314 315 if len(ips) > 0 { 316 return toNetIP(ips) 317 } 318 319 if err4 != nil { 320 return nil, err4 321 } 322 323 if err6 != nil { 324 return nil, err6 325 } 326 327 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 328 return nil, dns_feature.ErrEmptyResponse 329 } 330 331 return nil, errRecordNotFound 332 } 333 334 // QueryIP implements Server. 335 func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl 336 fqdn := Fqdn(domain) 337 338 if disableCache { 339 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 340 } else { 341 ips, err := s.findIPsForDomain(fqdn, option) 342 if err != errRecordNotFound { 343 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 344 return ips, err 345 } 346 } 347 348 // ipv4 and ipv6 belong to different subscription groups 349 var sub4, sub6 *pubsub.Subscriber 350 if option.IPv4Enable { 351 sub4 = s.pub.Subscribe(fqdn + "4") 352 defer sub4.Close() 353 } 354 if option.IPv6Enable { 355 sub6 = s.pub.Subscribe(fqdn + "6") 356 defer sub6.Close() 357 } 358 done := make(chan interface{}) 359 go func() { 360 if sub4 != nil { 361 select { 362 case <-sub4.Wait(): 363 case <-ctx.Done(): 364 } 365 } 366 if sub6 != nil { 367 select { 368 case <-sub6.Wait(): 369 case <-ctx.Done(): 370 } 371 } 372 close(done) 373 }() 374 s.sendQuery(ctx, fqdn, clientIP, option) 375 376 for { 377 ips, err := s.findIPsForDomain(fqdn, option) 378 if err != errRecordNotFound { 379 return ips, err 380 } 381 382 select { 383 case <-ctx.Done(): 384 return nil, ctx.Err() 385 case <-done: 386 } 387 } 388 }