github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/nameserver_doh.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "net/http" 9 "net/url" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/xmplusdev/xmcore/common" 15 "github.com/xmplusdev/xmcore/common/log" 16 "github.com/xmplusdev/xmcore/common/net" 17 "github.com/xmplusdev/xmcore/common/net/cnc" 18 "github.com/xmplusdev/xmcore/common/protocol/dns" 19 "github.com/xmplusdev/xmcore/common/session" 20 "github.com/xmplusdev/xmcore/common/signal/pubsub" 21 "github.com/xmplusdev/xmcore/common/task" 22 dns_feature "github.com/xmplusdev/xmcore/features/dns" 23 "github.com/xmplusdev/xmcore/features/routing" 24 "github.com/xmplusdev/xmcore/transport/internet" 25 "golang.org/x/net/dns/dnsmessage" 26 ) 27 28 // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format, 29 // which is compatible with traditional dns over udp(RFC1035), 30 // thus most of the DOH implementation is copied from udpns.go 31 type DoHNameServer struct { 32 dispatcher routing.Dispatcher 33 sync.RWMutex 34 ips map[string]*record 35 pub *pubsub.Service 36 cleanup *task.Periodic 37 reqID uint32 38 httpClient *http.Client 39 dohURL string 40 name string 41 queryStrategy QueryStrategy 42 } 43 44 // NewDoHNameServer creates DOH server object for remote resolving. 45 func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (*DoHNameServer, error) { 46 newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() 47 s := baseDOHNameServer(url, "DOH", queryStrategy) 48 49 s.dispatcher = dispatcher 50 tr := &http.Transport{ 51 MaxIdleConns: 30, 52 IdleConnTimeout: 90 * time.Second, 53 TLSHandshakeTimeout: 30 * time.Second, 54 ForceAttemptHTTP2: true, 55 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 56 dest, err := net.ParseDestination(network + ":" + addr) 57 if err != nil { 58 return nil, err 59 } 60 link, err := s.dispatcher.Dispatch(toDnsContext(ctx, s.dohURL), dest) 61 select { 62 case <-ctx.Done(): 63 return nil, ctx.Err() 64 default: 65 66 } 67 if err != nil { 68 return nil, err 69 } 70 71 cc := common.ChainedClosable{} 72 if cw, ok := link.Writer.(common.Closable); ok { 73 cc = append(cc, cw) 74 } 75 if cr, ok := link.Reader.(common.Closable); ok { 76 cc = append(cc, cr) 77 } 78 return cnc.NewConnection( 79 cnc.ConnectionInputMulti(link.Writer), 80 cnc.ConnectionOutputMulti(link.Reader), 81 cnc.ConnectionOnClose(cc), 82 ), nil 83 }, 84 } 85 s.httpClient = &http.Client{ 86 Timeout: time.Second * 180, 87 Transport: tr, 88 } 89 90 return s, nil 91 } 92 93 // NewDoHLocalNameServer creates DOH client object for local resolving 94 func NewDoHLocalNameServer(url *url.URL, queryStrategy QueryStrategy) *DoHNameServer { 95 url.Scheme = "https" 96 s := baseDOHNameServer(url, "DOHL", queryStrategy) 97 tr := &http.Transport{ 98 IdleConnTimeout: 90 * time.Second, 99 ForceAttemptHTTP2: true, 100 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 101 dest, err := net.ParseDestination(network + ":" + addr) 102 if err != nil { 103 return nil, err 104 } 105 conn, err := internet.DialSystem(ctx, dest, nil) 106 log.Record(&log.AccessMessage{ 107 From: "DNS", 108 To: s.dohURL, 109 Status: log.AccessAccepted, 110 Detour: "local", 111 }) 112 if err != nil { 113 return nil, err 114 } 115 return conn, nil 116 }, 117 } 118 s.httpClient = &http.Client{ 119 Timeout: time.Second * 180, 120 Transport: tr, 121 } 122 newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() 123 return s 124 } 125 126 func baseDOHNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) *DoHNameServer { 127 s := &DoHNameServer{ 128 ips: make(map[string]*record), 129 pub: pubsub.NewService(), 130 name: prefix + "//" + url.Host, 131 dohURL: url.String(), 132 queryStrategy: queryStrategy, 133 } 134 s.cleanup = &task.Periodic{ 135 Interval: time.Minute, 136 Execute: s.Cleanup, 137 } 138 return s 139 } 140 141 // Name implements Server. 142 func (s *DoHNameServer) Name() string { 143 return s.name 144 } 145 146 // Cleanup clears expired items from cache 147 func (s *DoHNameServer) Cleanup() error { 148 now := time.Now() 149 s.Lock() 150 defer s.Unlock() 151 152 if len(s.ips) == 0 { 153 return newError("nothing to do. stopping...") 154 } 155 156 for domain, record := range s.ips { 157 if record.A != nil && record.A.Expire.Before(now) { 158 record.A = nil 159 } 160 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 161 record.AAAA = nil 162 } 163 164 if record.A == nil && record.AAAA == nil { 165 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 166 delete(s.ips, domain) 167 } else { 168 s.ips[domain] = record 169 } 170 } 171 172 if len(s.ips) == 0 { 173 s.ips = make(map[string]*record) 174 } 175 176 return nil 177 } 178 179 func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 180 elapsed := time.Since(req.start) 181 182 s.Lock() 183 rec, found := s.ips[req.domain] 184 if !found { 185 rec = &record{} 186 } 187 updated := false 188 189 switch req.reqType { 190 case dnsmessage.TypeA: 191 if isNewer(rec.A, ipRec) { 192 rec.A = ipRec 193 updated = true 194 } 195 case dnsmessage.TypeAAAA: 196 addr := make([]net.Address, 0, len(ipRec.IP)) 197 for _, ip := range ipRec.IP { 198 if len(ip.IP()) == net.IPv6len { 199 addr = append(addr, ip) 200 } 201 } 202 ipRec.IP = addr 203 if isNewer(rec.AAAA, ipRec) { 204 rec.AAAA = ipRec 205 updated = true 206 } 207 } 208 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 209 210 if updated { 211 s.ips[req.domain] = rec 212 } 213 switch req.reqType { 214 case dnsmessage.TypeA: 215 s.pub.Publish(req.domain+"4", nil) 216 case dnsmessage.TypeAAAA: 217 s.pub.Publish(req.domain+"6", nil) 218 } 219 s.Unlock() 220 common.Must(s.cleanup.Start()) 221 } 222 223 func (s *DoHNameServer) newReqID() uint16 { 224 return uint16(atomic.AddUint32(&s.reqID, 1)) 225 } 226 227 func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 228 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 229 230 if s.name+"." == "DOH//"+domain { 231 newError(s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.").AtError().WriteToLog(session.ExportIDToError(ctx)) 232 return 233 } 234 235 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 236 237 var deadline time.Time 238 if d, ok := ctx.Deadline(); ok { 239 deadline = d 240 } else { 241 deadline = time.Now().Add(time.Second * 5) 242 } 243 244 for _, req := range reqs { 245 go func(r *dnsRequest) { 246 // generate new context for each req, using same context 247 // may cause reqs all aborted if any one encounter an error 248 dnsCtx := ctx 249 250 // reserve internal dns server requested Inbound 251 if inbound := session.InboundFromContext(ctx); inbound != nil { 252 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 253 } 254 255 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 256 Protocol: "https", 257 SkipDNSResolve: true, 258 }) 259 260 // forced to use mux for DOH 261 // dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) 262 263 var cancel context.CancelFunc 264 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 265 defer cancel() 266 267 b, err := dns.PackMessage(r.msg) 268 if err != nil { 269 newError("failed to pack dns query for ", domain).Base(err).AtError().WriteToLog() 270 return 271 } 272 resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) 273 if err != nil { 274 newError("failed to retrieve response for ", domain).Base(err).AtError().WriteToLog() 275 return 276 } 277 rec, err := parseResponse(resp) 278 if err != nil { 279 newError("failed to handle DOH response for ", domain).Base(err).AtError().WriteToLog() 280 return 281 } 282 s.updateIP(r, rec) 283 }(req) 284 } 285 } 286 287 func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { 288 body := bytes.NewBuffer(b) 289 req, err := http.NewRequest("POST", s.dohURL, body) 290 if err != nil { 291 return nil, err 292 } 293 294 req.Header.Add("Accept", "application/dns-message") 295 req.Header.Add("Content-Type", "application/dns-message") 296 297 hc := s.httpClient 298 299 resp, err := hc.Do(req.WithContext(ctx)) 300 if err != nil { 301 return nil, err 302 } 303 304 defer resp.Body.Close() 305 if resp.StatusCode != http.StatusOK { 306 io.Copy(io.Discard, resp.Body) // flush resp.Body so that the conn is reusable 307 return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode) 308 } 309 310 return io.ReadAll(resp.Body) 311 } 312 313 func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 314 s.RLock() 315 record, found := s.ips[domain] 316 s.RUnlock() 317 318 if !found { 319 return nil, errRecordNotFound 320 } 321 322 var err4 error 323 var err6 error 324 var ips []net.Address 325 var ip6 []net.Address 326 327 if option.IPv4Enable { 328 ips, err4 = record.A.getIPs() 329 } 330 331 if option.IPv6Enable { 332 ip6, err6 = record.AAAA.getIPs() 333 ips = append(ips, ip6...) 334 } 335 336 if len(ips) > 0 { 337 return toNetIP(ips) 338 } 339 340 if err4 != nil { 341 return nil, err4 342 } 343 344 if err6 != nil { 345 return nil, err6 346 } 347 348 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 349 return nil, dns_feature.ErrEmptyResponse 350 } 351 352 return nil, errRecordNotFound 353 } 354 355 // QueryIP implements Server. 356 func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl 357 fqdn := Fqdn(domain) 358 option = ResolveIpOptionOverride(s.queryStrategy, option) 359 if !option.IPv4Enable && !option.IPv6Enable { 360 return nil, dns_feature.ErrEmptyResponse 361 } 362 363 if disableCache { 364 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 365 } else { 366 ips, err := s.findIPsForDomain(fqdn, option) 367 if err != errRecordNotFound { 368 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 369 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 370 return ips, err 371 } 372 } 373 374 // ipv4 and ipv6 belong to different subscription groups 375 var sub4, sub6 *pubsub.Subscriber 376 if option.IPv4Enable { 377 sub4 = s.pub.Subscribe(fqdn + "4") 378 defer sub4.Close() 379 } 380 if option.IPv6Enable { 381 sub6 = s.pub.Subscribe(fqdn + "6") 382 defer sub6.Close() 383 } 384 done := make(chan interface{}) 385 go func() { 386 if sub4 != nil { 387 select { 388 case <-sub4.Wait(): 389 case <-ctx.Done(): 390 } 391 } 392 if sub6 != nil { 393 select { 394 case <-sub6.Wait(): 395 case <-ctx.Done(): 396 } 397 } 398 close(done) 399 }() 400 s.sendQuery(ctx, fqdn, clientIP, option) 401 start := time.Now() 402 403 for { 404 ips, err := s.findIPsForDomain(fqdn, option) 405 if err != errRecordNotFound { 406 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 407 return ips, err 408 } 409 410 select { 411 case <-ctx.Done(): 412 return nil, ctx.Err() 413 case <-done: 414 } 415 } 416 }