github.com/EagleQL/Xray-core@v1.4.3/app/dns/dohdns.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "net/url" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "github.com/xtls/xray-core/common" 16 "github.com/xtls/xray-core/common/log" 17 "github.com/xtls/xray-core/common/net" 18 "github.com/xtls/xray-core/common/net/cnc" 19 "github.com/xtls/xray-core/common/protocol/dns" 20 "github.com/xtls/xray-core/common/session" 21 "github.com/xtls/xray-core/common/signal/pubsub" 22 "github.com/xtls/xray-core/common/task" 23 dns_feature "github.com/xtls/xray-core/features/dns" 24 "github.com/xtls/xray-core/features/routing" 25 "github.com/xtls/xray-core/transport/internet" 26 "golang.org/x/net/dns/dnsmessage" 27 ) 28 29 // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format, 30 // which is compatible with traditional dns over udp(RFC1035), 31 // thus most of the DOH implementation is copied from udpns.go 32 type DoHNameServer struct { 33 dispatcher routing.Dispatcher 34 sync.RWMutex 35 ips map[string]record 36 pub *pubsub.Service 37 cleanup *task.Periodic 38 reqID uint32 39 clientIP net.IP 40 httpClient *http.Client 41 dohURL string 42 name string 43 } 44 45 // NewDoHNameServer creates DOH client object for remote resolving 46 func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { 47 newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() 48 s := baseDOHNameServer(url, "DOH", clientIP) 49 50 s.dispatcher = dispatcher 51 tr := &http.Transport{ 52 MaxIdleConns: 30, 53 IdleConnTimeout: 90 * time.Second, 54 TLSHandshakeTimeout: 30 * time.Second, 55 ForceAttemptHTTP2: true, 56 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 57 dispatcherCtx := context.Background() 58 59 dest, err := net.ParseDestination(network + ":" + addr) 60 if err != nil { 61 return nil, err 62 } 63 64 dispatcherCtx = session.ContextWithContent(dispatcherCtx, session.ContentFromContext(ctx)) 65 dispatcherCtx = session.ContextWithInbound(dispatcherCtx, session.InboundFromContext(ctx)) 66 dispatcherCtx = log.ContextWithAccessMessage(dispatcherCtx, &log.AccessMessage{ 67 From: "DoH", 68 To: s.dohURL, 69 Status: log.AccessAccepted, 70 Reason: "", 71 }) 72 73 link, err := s.dispatcher.Dispatch(dispatcherCtx, dest) 74 select { 75 case <-ctx.Done(): 76 return nil, ctx.Err() 77 default: 78 79 } 80 if err != nil { 81 return nil, err 82 } 83 84 cc := common.ChainedClosable{} 85 if cw, ok := link.Writer.(common.Closable); ok { 86 cc = append(cc, cw) 87 } 88 if cr, ok := link.Reader.(common.Closable); ok { 89 cc = append(cc, cr) 90 } 91 return cnc.NewConnection( 92 cnc.ConnectionInputMulti(link.Writer), 93 cnc.ConnectionOutputMulti(link.Reader), 94 cnc.ConnectionOnClose(cc), 95 ), nil 96 }, 97 } 98 s.httpClient = &http.Client{ 99 Timeout: time.Second * 180, 100 Transport: tr, 101 } 102 103 return s, nil 104 } 105 106 // NewDoHLocalNameServer creates DOH client object for local resolving 107 func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer { 108 url.Scheme = "https" 109 s := baseDOHNameServer(url, "DOHL", clientIP) 110 tr := &http.Transport{ 111 IdleConnTimeout: 90 * time.Second, 112 ForceAttemptHTTP2: true, 113 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 114 dest, err := net.ParseDestination(network + ":" + addr) 115 if err != nil { 116 return nil, err 117 } 118 conn, err := internet.DialSystem(ctx, dest, nil) 119 log.Record(&log.AccessMessage{ 120 From: "DoH", 121 To: s.dohURL, 122 Status: log.AccessAccepted, 123 Detour: "local", 124 }) 125 if err != nil { 126 return nil, err 127 } 128 return conn, nil 129 }, 130 } 131 s.httpClient = &http.Client{ 132 Timeout: time.Second * 180, 133 Transport: tr, 134 } 135 newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() 136 return s 137 } 138 139 func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer { 140 s := &DoHNameServer{ 141 ips: make(map[string]record), 142 clientIP: clientIP, 143 pub: pubsub.NewService(), 144 name: prefix + "//" + url.Host, 145 dohURL: url.String(), 146 } 147 s.cleanup = &task.Periodic{ 148 Interval: time.Minute, 149 Execute: s.Cleanup, 150 } 151 152 return s 153 } 154 155 // Name returns client name 156 func (s *DoHNameServer) Name() string { 157 return s.name 158 } 159 160 // Cleanup clears expired items from cache 161 func (s *DoHNameServer) Cleanup() error { 162 now := time.Now() 163 s.Lock() 164 defer s.Unlock() 165 166 if len(s.ips) == 0 { 167 return newError("nothing to do. stopping...") 168 } 169 170 for domain, record := range s.ips { 171 if record.A != nil && record.A.Expire.Before(now) { 172 record.A = nil 173 } 174 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 175 record.AAAA = nil 176 } 177 178 if record.A == nil && record.AAAA == nil { 179 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 180 delete(s.ips, domain) 181 } else { 182 s.ips[domain] = record 183 } 184 } 185 186 if len(s.ips) == 0 { 187 s.ips = make(map[string]record) 188 } 189 190 return nil 191 } 192 193 func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 194 elapsed := time.Since(req.start) 195 196 s.Lock() 197 rec := s.ips[req.domain] 198 updated := false 199 200 switch req.reqType { 201 case dnsmessage.TypeA: 202 if isNewer(rec.A, ipRec) { 203 rec.A = ipRec 204 updated = true 205 } 206 case dnsmessage.TypeAAAA: 207 addr := make([]net.Address, 0) 208 for _, ip := range ipRec.IP { 209 if len(ip.IP()) == net.IPv6len { 210 addr = append(addr, ip) 211 } 212 } 213 ipRec.IP = addr 214 if isNewer(rec.AAAA, ipRec) { 215 rec.AAAA = ipRec 216 updated = true 217 } 218 } 219 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 220 221 if updated { 222 s.ips[req.domain] = rec 223 } 224 switch req.reqType { 225 case dnsmessage.TypeA: 226 s.pub.Publish(req.domain+"4", nil) 227 case dnsmessage.TypeAAAA: 228 s.pub.Publish(req.domain+"6", nil) 229 } 230 s.Unlock() 231 common.Must(s.cleanup.Start()) 232 } 233 234 func (s *DoHNameServer) newReqID() uint16 { 235 return uint16(atomic.AddUint32(&s.reqID, 1)) 236 } 237 238 func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option dns_feature.IPOption) { 239 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 240 241 if s.name+"." == "DOH//"+domain { 242 newError(s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.").AtError().WriteToLog(session.ExportIDToError(ctx)) 243 return 244 } 245 246 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP)) 247 248 var deadline time.Time 249 if d, ok := ctx.Deadline(); ok { 250 deadline = d 251 } else { 252 deadline = time.Now().Add(time.Second * 5) 253 } 254 255 for _, req := range reqs { 256 go func(r *dnsRequest) { 257 // generate new context for each req, using same context 258 // may cause reqs all aborted if any one encounter an error 259 dnsCtx := context.Background() 260 261 // reserve internal dns server requested Inbound 262 if inbound := session.InboundFromContext(ctx); inbound != nil { 263 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 264 } 265 266 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 267 Protocol: "https", 268 //SkipRoutePick: true, 269 }) 270 271 // forced to use mux for DOH 272 // dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) 273 274 var cancel context.CancelFunc 275 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 276 defer cancel() 277 278 b, err := dns.PackMessage(r.msg) 279 if err != nil { 280 newError("failed to pack dns query for ", domain).Base(err).AtError().WriteToLog() 281 return 282 } 283 resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) 284 if err != nil { 285 newError("failed to retrieve response for ", domain).Base(err).AtError().WriteToLog() 286 return 287 } 288 rec, err := parseResponse(resp) 289 if err != nil { 290 newError("failed to handle DOH response for ", domain).Base(err).AtError().WriteToLog() 291 return 292 } 293 s.updateIP(r, rec) 294 }(req) 295 } 296 } 297 298 func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { 299 body := bytes.NewBuffer(b) 300 req, err := http.NewRequest("POST", s.dohURL, body) 301 if err != nil { 302 return nil, err 303 } 304 305 req.Header.Add("Accept", "application/dns-message") 306 req.Header.Add("Content-Type", "application/dns-message") 307 308 hc := s.httpClient 309 310 resp, err := hc.Do(req.WithContext(ctx)) 311 if err != nil { 312 return nil, err 313 } 314 315 defer resp.Body.Close() 316 if resp.StatusCode != http.StatusOK { 317 io.Copy(ioutil.Discard, resp.Body) // flush resp.Body so that the conn is reusable 318 return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode) 319 } 320 321 return ioutil.ReadAll(resp.Body) 322 } 323 324 func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 325 s.RLock() 326 record, found := s.ips[domain] 327 s.RUnlock() 328 329 if !found { 330 return nil, errRecordNotFound 331 } 332 333 var ips []net.Address 334 var lastErr error 335 if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess { 336 aaaa, err := record.AAAA.getIPs() 337 if err != nil { 338 lastErr = err 339 } 340 ips = append(ips, aaaa...) 341 } 342 343 if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess { 344 a, err := record.A.getIPs() 345 if err != nil { 346 lastErr = err 347 } 348 ips = append(ips, a...) 349 } 350 351 if len(ips) > 0 { 352 return toNetIP(ips), nil 353 } 354 355 if lastErr != nil { 356 return nil, lastErr 357 } 358 359 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 360 return nil, dns_feature.ErrEmptyResponse 361 } 362 363 return nil, errRecordNotFound 364 } 365 366 // QueryIP is called from dns.Server->queryIPTimeout 367 func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, error) { // nolint: dupl 368 fqdn := Fqdn(domain) 369 370 ips, err := s.findIPsForDomain(fqdn, option) 371 if err != errRecordNotFound { 372 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 373 log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err}) 374 return ips, err 375 } 376 377 // ipv4 and ipv6 belong to different subscription groups 378 var sub4, sub6 *pubsub.Subscriber 379 if option.IPv4Enable { 380 sub4 = s.pub.Subscribe(fqdn + "4") 381 defer sub4.Close() 382 } 383 if option.IPv6Enable { 384 sub6 = s.pub.Subscribe(fqdn + "6") 385 defer sub6.Close() 386 } 387 done := make(chan interface{}) 388 go func() { 389 if sub4 != nil { 390 select { 391 case <-sub4.Wait(): 392 case <-ctx.Done(): 393 } 394 } 395 if sub6 != nil { 396 select { 397 case <-sub6.Wait(): 398 case <-ctx.Done(): 399 } 400 } 401 close(done) 402 }() 403 s.sendQuery(ctx, fqdn, option) 404 start := time.Now() 405 406 for { 407 ips, err := s.findIPsForDomain(fqdn, option) 408 if err != errRecordNotFound { 409 log.Record(&log.DNSLog{s.name, domain, ips, log.DNSQueried, time.Since(start), err}) 410 return ips, err 411 } 412 413 select { 414 case <-ctx.Done(): 415 return nil, ctx.Err() 416 case <-done: 417 } 418 } 419 }