github.com/imannamdari/v2ray-core/v5@v5.0.5/app/dns/nameserver_doh.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "sync" 14 "sync/atomic" 15 "time" 16 17 "golang.org/x/net/dns/dnsmessage" 18 19 "github.com/imannamdari/v2ray-core/v5/common" 20 "github.com/imannamdari/v2ray-core/v5/common/net" 21 "github.com/imannamdari/v2ray-core/v5/common/protocol/dns" 22 "github.com/imannamdari/v2ray-core/v5/common/session" 23 "github.com/imannamdari/v2ray-core/v5/common/signal/pubsub" 24 "github.com/imannamdari/v2ray-core/v5/common/task" 25 dns_feature "github.com/imannamdari/v2ray-core/v5/features/dns" 26 "github.com/imannamdari/v2ray-core/v5/features/routing" 27 "github.com/imannamdari/v2ray-core/v5/transport/internet" 28 ) 29 30 // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format, 31 // which is compatible with traditional dns over udp(RFC1035), 32 // thus most of the DOH implementation is copied from udpns.go 33 type DoHNameServer struct { 34 sync.RWMutex 35 ips map[string]record 36 pub *pubsub.Service 37 cleanup *task.Periodic 38 reqID uint32 39 httpClient *http.Client 40 dohURL string 41 name string 42 } 43 44 // NewDoHNameServer creates DOH server object for remote resolving. 45 func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher) (*DoHNameServer, error) { 46 newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() 47 s := baseDOHNameServer(url, "DOH") 48 49 // Dispatched connection will be closed (interrupted) after each request 50 // This makes DOH inefficient without a keep-alived connection 51 // See: core/app/proxyman/outbound/handler.go:113 52 // Using mux (https request wrapped in a stream layer) improves the situation. 53 // Recommend to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on 54 // a normal network eg. the server side of v2ray 55 tr := &http.Transport{ 56 MaxIdleConns: 30, 57 IdleConnTimeout: 90 * time.Second, 58 TLSHandshakeTimeout: 30 * time.Second, 59 ForceAttemptHTTP2: true, 60 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 61 dest, err := net.ParseDestination(network + ":" + addr) 62 if err != nil { 63 return nil, err 64 } 65 66 link, err := dispatcher.Dispatch(ctx, dest) 67 if err != nil { 68 return nil, err 69 } 70 return net.NewConnection( 71 net.ConnectionInputMulti(link.Writer), 72 net.ConnectionOutputMulti(link.Reader), 73 ), nil 74 }, 75 } 76 77 dispatchedClient := &http.Client{ 78 Transport: tr, 79 Timeout: 60 * time.Second, 80 } 81 82 s.httpClient = dispatchedClient 83 return s, nil 84 } 85 86 // NewDoHLocalNameServer creates DOH client object for local resolving 87 func NewDoHLocalNameServer(url *url.URL) *DoHNameServer { 88 url.Scheme = "https" 89 s := baseDOHNameServer(url, "DOHL") 90 tr := &http.Transport{ 91 IdleConnTimeout: 90 * time.Second, 92 ForceAttemptHTTP2: true, 93 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 94 dest, err := net.ParseDestination(network + ":" + addr) 95 if err != nil { 96 return nil, err 97 } 98 conn, err := internet.DialSystem(ctx, dest, nil) 99 if err != nil { 100 return nil, err 101 } 102 return conn, nil 103 }, 104 } 105 s.httpClient = &http.Client{ 106 Timeout: time.Second * 180, 107 Transport: tr, 108 } 109 newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() 110 return s 111 } 112 113 func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { 114 s := &DoHNameServer{ 115 ips: make(map[string]record), 116 pub: pubsub.NewService(), 117 name: prefix + "//" + url.Host, 118 dohURL: url.String(), 119 } 120 s.cleanup = &task.Periodic{ 121 Interval: time.Minute, 122 Execute: s.Cleanup, 123 } 124 return s 125 } 126 127 // Name implements Server. 128 func (s *DoHNameServer) Name() string { 129 return s.name 130 } 131 132 // Cleanup clears expired items from cache 133 func (s *DoHNameServer) Cleanup() error { 134 now := time.Now() 135 s.Lock() 136 defer s.Unlock() 137 138 if len(s.ips) == 0 { 139 return newError("nothing to do. stopping...") 140 } 141 142 for domain, record := range s.ips { 143 if record.A != nil && record.A.Expire.Before(now) { 144 record.A = nil 145 } 146 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 147 record.AAAA = nil 148 } 149 150 if record.A == nil && record.AAAA == nil { 151 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 152 delete(s.ips, domain) 153 } else { 154 s.ips[domain] = record 155 } 156 } 157 158 if len(s.ips) == 0 { 159 s.ips = make(map[string]record) 160 } 161 162 return nil 163 } 164 165 func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 166 elapsed := time.Since(req.start) 167 168 s.Lock() 169 rec := s.ips[req.domain] 170 updated := false 171 172 switch req.reqType { 173 case dnsmessage.TypeA: 174 if isNewer(rec.A, ipRec) { 175 rec.A = ipRec 176 updated = true 177 } 178 case dnsmessage.TypeAAAA: 179 addr := make([]net.Address, 0) 180 for _, ip := range ipRec.IP { 181 if len(ip.IP()) == net.IPv6len { 182 addr = append(addr, ip) 183 } 184 } 185 ipRec.IP = addr 186 if isNewer(rec.AAAA, ipRec) { 187 rec.AAAA = ipRec 188 updated = true 189 } 190 } 191 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 192 193 if updated { 194 s.ips[req.domain] = rec 195 } 196 switch req.reqType { 197 case dnsmessage.TypeA: 198 s.pub.Publish(req.domain+"4", nil) 199 case dnsmessage.TypeAAAA: 200 s.pub.Publish(req.domain+"6", nil) 201 } 202 s.Unlock() 203 common.Must(s.cleanup.Start()) 204 } 205 206 func (s *DoHNameServer) newReqID() uint16 { 207 return uint16(atomic.AddUint32(&s.reqID, 1)) 208 } 209 210 func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 211 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 212 213 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 214 215 var deadline time.Time 216 if d, ok := ctx.Deadline(); ok { 217 deadline = d 218 } else { 219 deadline = time.Now().Add(time.Second * 5) 220 } 221 222 for _, req := range reqs { 223 go func(r *dnsRequest) { 224 // generate new context for each req, using same context 225 // may cause reqs all aborted if any one encounter an error 226 dnsCtx := ctx 227 228 // reserve internal dns server requested Inbound 229 if inbound := session.InboundFromContext(ctx); inbound != nil { 230 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 231 } 232 233 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 234 Protocol: "https", 235 SkipDNSResolve: true, 236 }) 237 238 // forced to use mux for DOH 239 dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true) 240 241 var cancel context.CancelFunc 242 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 243 defer cancel() 244 245 b, err := dns.PackMessage(r.msg) 246 if err != nil { 247 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 248 return 249 } 250 resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) 251 if err != nil { 252 newError("failed to retrieve response").Base(err).AtError().WriteToLog() 253 return 254 } 255 rec, err := parseResponse(resp) 256 if err != nil { 257 newError("failed to handle DOH response").Base(err).AtError().WriteToLog() 258 return 259 } 260 s.updateIP(r, rec) 261 }(req) 262 } 263 } 264 265 func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) { 266 body := bytes.NewBuffer(b) 267 req, err := http.NewRequest("POST", s.dohURL, body) 268 if err != nil { 269 return nil, err 270 } 271 272 req.Header.Add("Accept", "application/dns-message") 273 req.Header.Add("Content-Type", "application/dns-message") 274 275 resp, err := s.httpClient.Do(req.WithContext(ctx)) 276 if err != nil { 277 return nil, err 278 } 279 280 defer resp.Body.Close() 281 if resp.StatusCode != http.StatusOK { 282 io.Copy(io.Discard, resp.Body) // flush resp.Body so that the conn is reusable 283 return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode) 284 } 285 286 return io.ReadAll(resp.Body) 287 } 288 289 func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 290 s.RLock() 291 record, found := s.ips[domain] 292 s.RUnlock() 293 294 if !found { 295 return nil, errRecordNotFound 296 } 297 298 var ips []net.Address 299 var lastErr error 300 if option.IPv4Enable { 301 a, err := record.A.getIPs() 302 if err != nil { 303 lastErr = err 304 } 305 ips = append(ips, a...) 306 } 307 308 if option.IPv6Enable { 309 aaaa, err := record.AAAA.getIPs() 310 if err != nil { 311 lastErr = err 312 } 313 ips = append(ips, aaaa...) 314 } 315 316 if len(ips) > 0 { 317 return toNetIP(ips) 318 } 319 320 if lastErr != nil { 321 return nil, lastErr 322 } 323 324 return nil, dns_feature.ErrEmptyResponse 325 } 326 327 // QueryIP implements Server. 328 func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl 329 fqdn := Fqdn(domain) 330 331 if disableCache { 332 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 333 } else { 334 ips, err := s.findIPsForDomain(fqdn, option) 335 if err != errRecordNotFound { 336 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 337 return ips, err 338 } 339 } 340 341 // ipv4 and ipv6 belong to different subscription groups 342 var sub4, sub6 *pubsub.Subscriber 343 if option.IPv4Enable { 344 sub4 = s.pub.Subscribe(fqdn + "4") 345 defer sub4.Close() 346 } 347 if option.IPv6Enable { 348 sub6 = s.pub.Subscribe(fqdn + "6") 349 defer sub6.Close() 350 } 351 done := make(chan interface{}) 352 go func() { 353 if sub4 != nil { 354 select { 355 case <-sub4.Wait(): 356 case <-ctx.Done(): 357 } 358 } 359 if sub6 != nil { 360 select { 361 case <-sub6.Wait(): 362 case <-ctx.Done(): 363 } 364 } 365 close(done) 366 }() 367 s.sendQuery(ctx, fqdn, clientIP, option) 368 369 for { 370 ips, err := s.findIPsForDomain(fqdn, option) 371 if err != errRecordNotFound { 372 return ips, err 373 } 374 375 select { 376 case <-ctx.Done(): 377 return nil, ctx.Err() 378 case <-done: 379 } 380 } 381 }