github.com/moqsien/xraycore@v1.8.5/app/dns/nameserver_doh.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "net/http" 9 "net/url" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/moqsien/xraycore/common" 15 "github.com/moqsien/xraycore/common/log" 16 "github.com/moqsien/xraycore/common/net" 17 "github.com/moqsien/xraycore/common/net/cnc" 18 "github.com/moqsien/xraycore/common/protocol/dns" 19 "github.com/moqsien/xraycore/common/session" 20 "github.com/moqsien/xraycore/common/signal/pubsub" 21 "github.com/moqsien/xraycore/common/task" 22 dns_feature "github.com/moqsien/xraycore/features/dns" 23 "github.com/moqsien/xraycore/features/routing" 24 "github.com/moqsien/xraycore/transport/internet" 25 "golang.org/x/net/dns/dnsmessage" 26 ) 27 28 // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format, 29 // which is compatible with traditional dns over udp(RFC1035), 30 // thus most of the DOH implementation is copied from udpns.go 31 type DoHNameServer struct { 32 dispatcher routing.Dispatcher 33 sync.RWMutex 34 ips map[string]*record 35 pub *pubsub.Service 36 cleanup *task.Periodic 37 reqID uint32 38 httpClient *http.Client 39 dohURL string 40 name string 41 } 42 43 // NewDoHNameServer creates DOH server object for remote resolving. 44 func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher) (*DoHNameServer, error) { 45 newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() 46 s := baseDOHNameServer(url, "DOH") 47 48 s.dispatcher = dispatcher 49 tr := &http.Transport{ 50 MaxIdleConns: 30, 51 IdleConnTimeout: 90 * time.Second, 52 TLSHandshakeTimeout: 30 * time.Second, 53 ForceAttemptHTTP2: true, 54 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 55 dest, err := net.ParseDestination(network + ":" + addr) 56 if err != nil { 57 return nil, err 58 } 59 link, err := s.dispatcher.Dispatch(toDnsContext(ctx, s.dohURL), dest) 60 select { 61 case <-ctx.Done(): 62 return nil, ctx.Err() 63 default: 64 65 } 66 if err != nil { 67 return nil, err 68 } 69 70 cc := common.ChainedClosable{} 71 if cw, ok := link.Writer.(common.Closable); ok { 72 cc = append(cc, cw) 73 } 74 if cr, ok := link.Reader.(common.Closable); ok { 75 cc = append(cc, cr) 76 } 77 return cnc.NewConnection( 78 cnc.ConnectionInputMulti(link.Writer), 79 cnc.ConnectionOutputMulti(link.Reader), 80 cnc.ConnectionOnClose(cc), 81 ), nil 82 }, 83 } 84 s.httpClient = &http.Client{ 85 Timeout: time.Second * 180, 86 Transport: tr, 87 } 88 89 return s, nil 90 } 91 92 // NewDoHLocalNameServer creates DOH client object for local resolving 93 func NewDoHLocalNameServer(url *url.URL) *DoHNameServer { 94 url.Scheme = "https" 95 s := baseDOHNameServer(url, "DOHL") 96 tr := &http.Transport{ 97 IdleConnTimeout: 90 * time.Second, 98 ForceAttemptHTTP2: true, 99 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 100 dest, err := net.ParseDestination(network + ":" + addr) 101 if err != nil { 102 return nil, err 103 } 104 conn, err := internet.DialSystem(ctx, dest, nil) 105 log.Record(&log.AccessMessage{ 106 From: "DNS", 107 To: s.dohURL, 108 Status: log.AccessAccepted, 109 Detour: "local", 110 }) 111 if err != nil { 112 return nil, err 113 } 114 return conn, nil 115 }, 116 } 117 s.httpClient = &http.Client{ 118 Timeout: time.Second * 180, 119 Transport: tr, 120 } 121 newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() 122 return s 123 } 124 125 func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { 126 s := &DoHNameServer{ 127 ips: make(map[string]*record), 128 pub: pubsub.NewService(), 129 name: prefix + "//" + url.Host, 130 dohURL: url.String(), 131 } 132 s.cleanup = &task.Periodic{ 133 Interval: time.Minute, 134 Execute: s.Cleanup, 135 } 136 return s 137 } 138 139 // Name implements Server. 140 func (s *DoHNameServer) Name() string { 141 return s.name 142 } 143 144 // Cleanup clears expired items from cache 145 func (s *DoHNameServer) Cleanup() error { 146 now := time.Now() 147 s.Lock() 148 defer s.Unlock() 149 150 if len(s.ips) == 0 { 151 return newError("nothing to do. stopping...") 152 } 153 154 for domain, record := range s.ips { 155 if record.A != nil && record.A.Expire.Before(now) { 156 record.A = nil 157 } 158 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 159 record.AAAA = nil 160 } 161 162 if record.A == nil && record.AAAA == nil { 163 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 164 delete(s.ips, domain) 165 } else { 166 s.ips[domain] = record 167 } 168 } 169 170 if len(s.ips) == 0 { 171 s.ips = make(map[string]*record) 172 } 173 174 return nil 175 } 176 177 func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 178 elapsed := time.Since(req.start) 179 180 s.Lock() 181 rec, found := s.ips[req.domain] 182 if !found { 183 rec = &record{} 184 } 185 updated := false 186 187 switch req.reqType { 188 case dnsmessage.TypeA: 189 if isNewer(rec.A, ipRec) { 190 rec.A = ipRec 191 updated = true 192 } 193 case dnsmessage.TypeAAAA: 194 addr := make([]net.Address, 0, len(ipRec.IP)) 195 for _, ip := range ipRec.IP { 196 if len(ip.IP()) == net.IPv6len { 197 addr = append(addr, ip) 198 } 199 } 200 ipRec.IP = addr 201 if isNewer(rec.AAAA, ipRec) { 202 rec.AAAA = ipRec 203 updated = true 204 } 205 } 206 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 207 208 if updated { 209 s.ips[req.domain] = rec 210 } 211 switch req.reqType { 212 case dnsmessage.TypeA: 213 s.pub.Publish(req.domain+"4", nil) 214 case dnsmessage.TypeAAAA: 215 s.pub.Publish(req.domain+"6", nil) 216 } 217 s.Unlock() 218 common.Must(s.cleanup.Start()) 219 } 220 221 func (s *DoHNameServer) newReqID() uint16 { 222 return uint16(atomic.AddUint32(&s.reqID, 1)) 223 } 224 225 func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 226 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 227 228 if s.name+"." == "DOH//"+domain { 229 newError(s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.").AtError().WriteToLog(session.ExportIDToError(ctx)) 230 return 231 } 232 233 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 234 235 var deadline time.Time 236 if d, ok := ctx.Deadline(); ok { 237 deadline = d 238 } else { 239 deadline = time.Now().Add(time.Second * 5) 240 } 241 242 for _, req := range reqs { 243 go func(r *dnsRequest) { 244 // generate new context for each req, using same context 245 // may cause reqs all aborted if any one encounter an error 246 dnsCtx := ctx 247 248 // reserve internal dns server requested Inbound 249 if inbound := session.InboundFromContext(ctx); inbound != nil { 250 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 251 } 252 253 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 254 Protocol: "https", 255 SkipDNSResolve: true, 256 }) 257 258 // forced to use mux for DOH 259 // dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) 260 261 var cancel context.CancelFunc 262 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 263 defer cancel() 264 265 b, err := dns.PackMessage(r.msg) 266 if err != nil { 267 newError("failed to pack dns query for ", domain).Base(err).AtError().WriteToLog() 268 return 269 } 270 resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) 271 if err != nil { 272 newError("failed to retrieve response for ", domain).Base(err).AtError().WriteToLog() 273 return 274 } 275 rec, err := parseResponse(resp) 276 if err != nil { 277 newError("failed to handle DOH response for ", domain).Base(err).AtError().WriteToLog() 278 return 279 } 280 s.updateIP(r, rec) 281 }(req) 282 } 283 } 284 285 func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { 286 body := bytes.NewBuffer(b) 287 req, err := http.NewRequest("POST", s.dohURL, body) 288 if err != nil { 289 return nil, err 290 } 291 292 req.Header.Add("Accept", "application/dns-message") 293 req.Header.Add("Content-Type", "application/dns-message") 294 295 hc := s.httpClient 296 297 resp, err := hc.Do(req.WithContext(ctx)) 298 if err != nil { 299 return nil, err 300 } 301 302 defer resp.Body.Close() 303 if resp.StatusCode != http.StatusOK { 304 io.Copy(io.Discard, resp.Body) // flush resp.Body so that the conn is reusable 305 return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode) 306 } 307 308 return io.ReadAll(resp.Body) 309 } 310 311 func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 312 s.RLock() 313 record, found := s.ips[domain] 314 s.RUnlock() 315 316 if !found { 317 return nil, errRecordNotFound 318 } 319 320 var err4 error 321 var err6 error 322 var ips []net.Address 323 var ip6 []net.Address 324 325 if option.IPv4Enable { 326 ips, err4 = record.A.getIPs() 327 } 328 329 if option.IPv6Enable { 330 ip6, err6 = record.AAAA.getIPs() 331 ips = append(ips, ip6...) 332 } 333 334 if len(ips) > 0 { 335 return toNetIP(ips) 336 } 337 338 if err4 != nil { 339 return nil, err4 340 } 341 342 if err6 != nil { 343 return nil, err6 344 } 345 346 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 347 return nil, dns_feature.ErrEmptyResponse 348 } 349 350 return nil, errRecordNotFound 351 } 352 353 // QueryIP implements Server. 354 func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl 355 fqdn := Fqdn(domain) 356 357 if disableCache { 358 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 359 } else { 360 ips, err := s.findIPsForDomain(fqdn, option) 361 if err != errRecordNotFound { 362 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 363 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 364 return ips, err 365 } 366 } 367 368 // ipv4 and ipv6 belong to different subscription groups 369 var sub4, sub6 *pubsub.Subscriber 370 if option.IPv4Enable { 371 sub4 = s.pub.Subscribe(fqdn + "4") 372 defer sub4.Close() 373 } 374 if option.IPv6Enable { 375 sub6 = s.pub.Subscribe(fqdn + "6") 376 defer sub6.Close() 377 } 378 done := make(chan interface{}) 379 go func() { 380 if sub4 != nil { 381 select { 382 case <-sub4.Wait(): 383 case <-ctx.Done(): 384 } 385 } 386 if sub6 != nil { 387 select { 388 case <-sub6.Wait(): 389 case <-ctx.Done(): 390 } 391 } 392 close(done) 393 }() 394 s.sendQuery(ctx, fqdn, clientIP, option) 395 start := time.Now() 396 397 for { 398 ips, err := s.findIPsForDomain(fqdn, option) 399 if err != errRecordNotFound { 400 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 401 return ips, err 402 } 403 404 select { 405 case <-ctx.Done(): 406 return nil, ctx.Err() 407 case <-done: 408 } 409 } 410 }