github.com/xraypb/Xray-core@v1.8.1/app/dns/nameserver_quic.go (about) 1 package dns 2 3 import ( 4 "context" 5 "net/url" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/quic-go/quic-go" 11 "github.com/xraypb/Xray-core/common" 12 "github.com/xraypb/Xray-core/common/buf" 13 "github.com/xraypb/Xray-core/common/log" 14 "github.com/xraypb/Xray-core/common/net" 15 "github.com/xraypb/Xray-core/common/protocol/dns" 16 "github.com/xraypb/Xray-core/common/session" 17 "github.com/xraypb/Xray-core/common/signal/pubsub" 18 "github.com/xraypb/Xray-core/common/task" 19 dns_feature "github.com/xraypb/Xray-core/features/dns" 20 "github.com/xraypb/Xray-core/transport/internet/tls" 21 "golang.org/x/net/dns/dnsmessage" 22 "golang.org/x/net/http2" 23 ) 24 25 // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated 26 // by selecting the ALPN token "dq" in the crypto handshake. 27 const NextProtoDQ = "doq-i00" 28 29 const handshakeTimeout = time.Second * 8 30 31 // QUICNameServer implemented DNS over QUIC 32 type QUICNameServer struct { 33 sync.RWMutex 34 ips map[string]*record 35 pub *pubsub.Service 36 cleanup *task.Periodic 37 reqID uint32 38 name string 39 destination *net.Destination 40 connection quic.Connection 41 } 42 43 // NewQUICNameServer creates DNS-over-QUIC client object for local resolving 44 func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) { 45 newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog() 46 47 var err error 48 port := net.Port(784) 49 if url.Port() != "" { 50 port, err = net.PortFromString(url.Port()) 51 if err != nil { 52 return nil, err 53 } 54 } 55 dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) 56 57 s := &QUICNameServer{ 58 ips: make(map[string]*record), 59 pub: pubsub.NewService(), 60 name: url.String(), 61 destination: &dest, 62 } 63 s.cleanup = &task.Periodic{ 64 Interval: time.Minute, 65 Execute: s.Cleanup, 66 } 67 68 return s, nil 69 } 70 71 // Name returns client name 72 func (s *QUICNameServer) Name() string { 73 return s.name 74 } 75 76 // Cleanup clears expired items from cache 77 func (s *QUICNameServer) Cleanup() error { 78 now := time.Now() 79 s.Lock() 80 defer s.Unlock() 81 82 if len(s.ips) == 0 { 83 return newError("nothing to do. stopping...") 84 } 85 86 for domain, record := range s.ips { 87 if record.A != nil && record.A.Expire.Before(now) { 88 record.A = nil 89 } 90 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 91 record.AAAA = nil 92 } 93 94 if record.A == nil && record.AAAA == nil { 95 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 96 delete(s.ips, domain) 97 } else { 98 s.ips[domain] = record 99 } 100 } 101 102 if len(s.ips) == 0 { 103 s.ips = make(map[string]*record) 104 } 105 106 return nil 107 } 108 109 func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 110 elapsed := time.Since(req.start) 111 112 s.Lock() 113 rec, found := s.ips[req.domain] 114 if !found { 115 rec = &record{} 116 } 117 updated := false 118 119 switch req.reqType { 120 case dnsmessage.TypeA: 121 if isNewer(rec.A, ipRec) { 122 rec.A = ipRec 123 updated = true 124 } 125 case dnsmessage.TypeAAAA: 126 addr := make([]net.Address, 0) 127 for _, ip := range ipRec.IP { 128 if len(ip.IP()) == net.IPv6len { 129 addr = append(addr, ip) 130 } 131 } 132 ipRec.IP = addr 133 if isNewer(rec.AAAA, ipRec) { 134 rec.AAAA = ipRec 135 updated = true 136 } 137 } 138 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 139 140 if updated { 141 s.ips[req.domain] = rec 142 } 143 switch req.reqType { 144 case dnsmessage.TypeA: 145 s.pub.Publish(req.domain+"4", nil) 146 case dnsmessage.TypeAAAA: 147 s.pub.Publish(req.domain+"6", nil) 148 } 149 s.Unlock() 150 common.Must(s.cleanup.Start()) 151 } 152 153 func (s *QUICNameServer) newReqID() uint16 { 154 return uint16(atomic.AddUint32(&s.reqID, 1)) 155 } 156 157 func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 158 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 159 160 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 161 162 var deadline time.Time 163 if d, ok := ctx.Deadline(); ok { 164 deadline = d 165 } else { 166 deadline = time.Now().Add(time.Second * 5) 167 } 168 169 for _, req := range reqs { 170 go func(r *dnsRequest) { 171 // generate new context for each req, using same context 172 // may cause reqs all aborted if any one encounter an error 173 dnsCtx := ctx 174 175 // reserve internal dns server requested Inbound 176 if inbound := session.InboundFromContext(ctx); inbound != nil { 177 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 178 } 179 180 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 181 Protocol: "quic", 182 SkipDNSResolve: true, 183 }) 184 185 var cancel context.CancelFunc 186 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 187 defer cancel() 188 189 b, err := dns.PackMessage(r.msg) 190 if err != nil { 191 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 192 return 193 } 194 195 conn, err := s.openStream(dnsCtx) 196 if err != nil { 197 newError("failed to open quic connection").Base(err).AtError().WriteToLog() 198 return 199 } 200 201 _, err = conn.Write(b.Bytes()) 202 if err != nil { 203 newError("failed to send query").Base(err).AtError().WriteToLog() 204 return 205 } 206 207 _ = conn.Close() 208 209 respBuf := buf.New() 210 defer respBuf.Release() 211 n, err := respBuf.ReadFrom(conn) 212 if err != nil && n == 0 { 213 newError("failed to read response").Base(err).AtError().WriteToLog() 214 return 215 } 216 217 rec, err := parseResponse(respBuf.Bytes()) 218 if err != nil { 219 newError("failed to handle response").Base(err).AtError().WriteToLog() 220 return 221 } 222 s.updateIP(r, rec) 223 }(req) 224 } 225 } 226 227 func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 228 s.RLock() 229 record, found := s.ips[domain] 230 s.RUnlock() 231 232 if !found { 233 return nil, errRecordNotFound 234 } 235 236 var err4 error 237 var err6 error 238 var ips []net.Address 239 var ip6 []net.Address 240 241 if option.IPv4Enable { 242 ips, err4 = record.A.getIPs() 243 } 244 245 if option.IPv6Enable { 246 ip6, err6 = record.AAAA.getIPs() 247 ips = append(ips, ip6...) 248 } 249 250 if len(ips) > 0 { 251 return toNetIP(ips) 252 } 253 254 if err4 != nil { 255 return nil, err4 256 } 257 258 if err6 != nil { 259 return nil, err6 260 } 261 262 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 263 return nil, dns_feature.ErrEmptyResponse 264 } 265 266 return nil, errRecordNotFound 267 } 268 269 // QueryIP is called from dns.Server->queryIPTimeout 270 func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 271 fqdn := Fqdn(domain) 272 273 if disableCache { 274 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 275 } else { 276 ips, err := s.findIPsForDomain(fqdn, option) 277 if err != errRecordNotFound { 278 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 279 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 280 return ips, err 281 } 282 } 283 284 // ipv4 and ipv6 belong to different subscription groups 285 var sub4, sub6 *pubsub.Subscriber 286 if option.IPv4Enable { 287 sub4 = s.pub.Subscribe(fqdn + "4") 288 defer sub4.Close() 289 } 290 if option.IPv6Enable { 291 sub6 = s.pub.Subscribe(fqdn + "6") 292 defer sub6.Close() 293 } 294 done := make(chan interface{}) 295 go func() { 296 if sub4 != nil { 297 select { 298 case <-sub4.Wait(): 299 case <-ctx.Done(): 300 } 301 } 302 if sub6 != nil { 303 select { 304 case <-sub6.Wait(): 305 case <-ctx.Done(): 306 } 307 } 308 close(done) 309 }() 310 s.sendQuery(ctx, fqdn, clientIP, option) 311 start := time.Now() 312 313 for { 314 ips, err := s.findIPsForDomain(fqdn, option) 315 if err != errRecordNotFound { 316 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 317 return ips, err 318 } 319 320 select { 321 case <-ctx.Done(): 322 return nil, ctx.Err() 323 case <-done: 324 } 325 } 326 } 327 328 func isActive(s quic.Connection) bool { 329 select { 330 case <-s.Context().Done(): 331 return false 332 default: 333 return true 334 } 335 } 336 337 func (s *QUICNameServer) getConnection() (quic.Connection, error) { 338 var conn quic.Connection 339 s.RLock() 340 conn = s.connection 341 if conn != nil && isActive(conn) { 342 s.RUnlock() 343 return conn, nil 344 } 345 if conn != nil { 346 // we're recreating the connection, let's create a new one 347 _ = conn.CloseWithError(0, "") 348 } 349 s.RUnlock() 350 351 s.Lock() 352 defer s.Unlock() 353 354 var err error 355 conn, err = s.openConnection() 356 if err != nil { 357 // This does not look too nice, but QUIC (or maybe quic-go) 358 // doesn't seem stable enough. 359 // Maybe retransmissions aren't fully implemented in quic-go? 360 // Anyways, the simple solution is to make a second try when 361 // it fails to open the QUIC connection. 362 conn, err = s.openConnection() 363 if err != nil { 364 return nil, err 365 } 366 } 367 s.connection = conn 368 return conn, nil 369 } 370 371 func (s *QUICNameServer) openConnection() (quic.Connection, error) { 372 tlsConfig := tls.Config{} 373 quicConfig := &quic.Config{ 374 HandshakeIdleTimeout: handshakeTimeout, 375 } 376 377 conn, err := quic.DialAddrContext(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig) 378 log.Record(&log.AccessMessage{ 379 From: "DNS", 380 To: s.destination, 381 Status: log.AccessAccepted, 382 Detour: "local", 383 }) 384 if err != nil { 385 return nil, err 386 } 387 388 return conn, nil 389 } 390 391 func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) { 392 conn, err := s.getConnection() 393 if err != nil { 394 return nil, err 395 } 396 397 // open a new stream 398 return conn.OpenStreamSync(ctx) 399 }