github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/dns/nameserver_quic.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "net/url" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/quic-go/quic-go" 13 "golang.org/x/net/dns/dnsmessage" 14 15 "github.com/v2fly/v2ray-core/v5/common" 16 "github.com/v2fly/v2ray-core/v5/common/buf" 17 "github.com/v2fly/v2ray-core/v5/common/net" 18 "github.com/v2fly/v2ray-core/v5/common/protocol/dns" 19 "github.com/v2fly/v2ray-core/v5/common/session" 20 "github.com/v2fly/v2ray-core/v5/common/signal/pubsub" 21 "github.com/v2fly/v2ray-core/v5/common/task" 22 dns_feature "github.com/v2fly/v2ray-core/v5/features/dns" 23 "github.com/v2fly/v2ray-core/v5/transport/internet/tls" 24 ) 25 26 // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated 27 // by selecting the ALPN token "doq" in the crypto handshake. 28 const NextProtoDQ = "doq" 29 30 const handshakeIdleTimeout = time.Second * 8 31 32 // QUICNameServer implemented DNS over QUIC 33 type QUICNameServer struct { 34 sync.RWMutex 35 ips map[string]record 36 pub *pubsub.Service 37 cleanup *task.Periodic 38 reqID uint32 39 name string 40 destination net.Destination 41 connection quic.Connection 42 } 43 44 // NewQUICNameServer creates DNS-over-QUIC client object for local resolving 45 func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) { 46 newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog() 47 48 var err error 49 port := net.Port(853) 50 if url.Port() != "" { 51 port, err = net.PortFromString(url.Port()) 52 if err != nil { 53 return nil, err 54 } 55 } 56 dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) 57 58 s := &QUICNameServer{ 59 ips: make(map[string]record), 60 pub: pubsub.NewService(), 61 name: url.String(), 62 destination: dest, 63 } 64 s.cleanup = &task.Periodic{ 65 Interval: time.Minute, 66 Execute: s.Cleanup, 67 } 68 69 return s, nil 70 } 71 72 // Name returns client name 73 func (s *QUICNameServer) Name() string { 74 return s.name 75 } 76 77 // Cleanup clears expired items from cache 78 func (s *QUICNameServer) Cleanup() error { 79 now := time.Now() 80 s.Lock() 81 defer s.Unlock() 82 83 if len(s.ips) == 0 { 84 return newError("nothing to do. stopping...") 85 } 86 87 for domain, record := range s.ips { 88 if record.A != nil && record.A.Expire.Before(now) { 89 record.A = nil 90 } 91 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 92 record.AAAA = nil 93 } 94 95 if record.A == nil && record.AAAA == nil { 96 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 97 delete(s.ips, domain) 98 } else { 99 s.ips[domain] = record 100 } 101 } 102 103 if len(s.ips) == 0 { 104 s.ips = make(map[string]record) 105 } 106 107 return nil 108 } 109 110 func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 111 elapsed := time.Since(req.start) 112 113 s.Lock() 114 rec := s.ips[req.domain] 115 updated := false 116 117 switch req.reqType { 118 case dnsmessage.TypeA: 119 if isNewer(rec.A, ipRec) { 120 rec.A = ipRec 121 updated = true 122 } 123 case dnsmessage.TypeAAAA: 124 addr := make([]net.Address, 0) 125 for _, ip := range ipRec.IP { 126 if len(ip.IP()) == net.IPv6len { 127 addr = append(addr, ip) 128 } 129 } 130 ipRec.IP = addr 131 if isNewer(rec.AAAA, ipRec) { 132 rec.AAAA = ipRec 133 updated = true 134 } 135 } 136 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 137 138 if updated { 139 s.ips[req.domain] = rec 140 } 141 switch req.reqType { 142 case dnsmessage.TypeA: 143 s.pub.Publish(req.domain+"4", nil) 144 case dnsmessage.TypeAAAA: 145 s.pub.Publish(req.domain+"6", nil) 146 } 147 s.Unlock() 148 common.Must(s.cleanup.Start()) 149 } 150 151 func (s *QUICNameServer) newReqID() uint16 { 152 return uint16(atomic.AddUint32(&s.reqID, 1)) 153 } 154 155 func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 156 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 157 158 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 159 160 var deadline time.Time 161 if d, ok := ctx.Deadline(); ok { 162 deadline = d 163 } else { 164 deadline = time.Now().Add(time.Second * 5) 165 } 166 167 for _, req := range reqs { 168 go func(r *dnsRequest) { 169 // generate new context for each req, using same context 170 // may cause reqs all aborted if any one encounter an error 171 dnsCtx := ctx 172 173 // reserve internal dns server requested Inbound 174 if inbound := session.InboundFromContext(ctx); inbound != nil { 175 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 176 } 177 178 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 179 Protocol: "quic", 180 SkipDNSResolve: true, 181 }) 182 183 var cancel context.CancelFunc 184 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 185 defer cancel() 186 187 b, err := dns.PackMessage(r.msg) 188 if err != nil { 189 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 190 return 191 } 192 193 dnsReqBuf := buf.New() 194 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 195 dnsReqBuf.Write(b.Bytes()) 196 b.Release() 197 198 conn, err := s.openStream(dnsCtx) 199 if err != nil { 200 newError("failed to open quic connection").Base(err).AtError().WriteToLog() 201 return 202 } 203 204 _, err = conn.Write(dnsReqBuf.Bytes()) 205 if err != nil { 206 newError("failed to send query").Base(err).AtError().WriteToLog() 207 return 208 } 209 210 _ = conn.Close() 211 212 respBuf := buf.New() 213 defer respBuf.Release() 214 n, err := respBuf.ReadFullFrom(conn, 2) 215 if err != nil && n == 0 { 216 newError("failed to read response length").Base(err).AtError().WriteToLog() 217 return 218 } 219 var length int16 220 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 221 if err != nil { 222 newError("failed to parse response length").Base(err).AtError().WriteToLog() 223 return 224 } 225 respBuf.Clear() 226 n, err = respBuf.ReadFullFrom(conn, int32(length)) 227 if err != nil && n == 0 { 228 newError("failed to read response length").Base(err).AtError().WriteToLog() 229 return 230 } 231 232 rec, err := parseResponse(respBuf.Bytes()) 233 if err != nil { 234 newError("failed to handle response").Base(err).AtError().WriteToLog() 235 return 236 } 237 s.updateIP(r, rec) 238 }(req) 239 } 240 } 241 242 func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 243 s.RLock() 244 record, found := s.ips[domain] 245 s.RUnlock() 246 247 if !found { 248 return nil, errRecordNotFound 249 } 250 251 var ips []net.Address 252 var lastErr error 253 if option.IPv4Enable { 254 a, err := record.A.getIPs() 255 if err != nil { 256 lastErr = err 257 } 258 ips = append(ips, a...) 259 } 260 261 if option.IPv6Enable { 262 aaaa, err := record.AAAA.getIPs() 263 if err != nil { 264 lastErr = err 265 } 266 ips = append(ips, aaaa...) 267 } 268 269 if len(ips) > 0 { 270 return toNetIP(ips) 271 } 272 273 if lastErr != nil { 274 return nil, lastErr 275 } 276 277 return nil, dns_feature.ErrEmptyResponse 278 } 279 280 // QueryIP is called from dns.Server->queryIPTimeout 281 func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 282 fqdn := Fqdn(domain) 283 284 if disableCache { 285 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 286 } else { 287 ips, err := s.findIPsForDomain(fqdn, option) 288 if err != errRecordNotFound { 289 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 290 return ips, err 291 } 292 } 293 294 // ipv4 and ipv6 belong to different subscription groups 295 var sub4, sub6 *pubsub.Subscriber 296 if option.IPv4Enable { 297 sub4 = s.pub.Subscribe(fqdn + "4") 298 defer sub4.Close() 299 } 300 if option.IPv6Enable { 301 sub6 = s.pub.Subscribe(fqdn + "6") 302 defer sub6.Close() 303 } 304 done := make(chan interface{}) 305 go func() { 306 if sub4 != nil { 307 select { 308 case <-sub4.Wait(): 309 case <-ctx.Done(): 310 } 311 } 312 if sub6 != nil { 313 select { 314 case <-sub6.Wait(): 315 case <-ctx.Done(): 316 } 317 } 318 close(done) 319 }() 320 s.sendQuery(ctx, fqdn, clientIP, option) 321 322 for { 323 ips, err := s.findIPsForDomain(fqdn, option) 324 if err != errRecordNotFound { 325 return ips, err 326 } 327 328 select { 329 case <-ctx.Done(): 330 return nil, ctx.Err() 331 case <-done: 332 } 333 } 334 } 335 336 func isActive(s quic.Connection) bool { 337 select { 338 case <-s.Context().Done(): 339 return false 340 default: 341 return true 342 } 343 } 344 345 func (s *QUICNameServer) getConnection(ctx context.Context) (quic.Connection, error) { 346 var conn quic.Connection 347 s.RLock() 348 conn = s.connection 349 if conn != nil && isActive(conn) { 350 s.RUnlock() 351 return conn, nil 352 } 353 if conn != nil { 354 // we're recreating the connection, let's create a new one 355 _ = conn.CloseWithError(0, "") 356 } 357 s.RUnlock() 358 359 s.Lock() 360 defer s.Unlock() 361 362 var err error 363 conn, err = s.openConnection(ctx) 364 if err != nil { 365 // This does not look too nice, but QUIC (or maybe quic-go) 366 // doesn't seem stable enough. 367 // Maybe retransmissions aren't fully implemented in quic-go? 368 // Anyways, the simple solution is to make a second try when 369 // it fails to open the QUIC connection. 370 conn, err = s.openConnection(ctx) 371 if err != nil { 372 return nil, err 373 } 374 } 375 s.connection = conn 376 return conn, nil 377 } 378 379 func (s *QUICNameServer) openConnection(ctx context.Context) (quic.Connection, error) { 380 tlsConfig := tls.Config{ 381 ServerName: func() string { 382 switch s.destination.Address.Family() { 383 case net.AddressFamilyIPv4, net.AddressFamilyIPv6: 384 return s.destination.Address.IP().String() 385 case net.AddressFamilyDomain: 386 return s.destination.Address.Domain() 387 default: 388 panic("unknown address family") 389 } 390 }(), 391 } 392 quicConfig := &quic.Config{ 393 HandshakeIdleTimeout: handshakeIdleTimeout, 394 } 395 396 conn, err := quic.DialAddr(ctx, s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto(NextProtoDQ)), quicConfig) 397 if err != nil { 398 return nil, err 399 } 400 401 return conn, nil 402 } 403 404 func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) { 405 conn, err := s.getConnection(ctx) 406 if err != nil { 407 return nil, err 408 } 409 410 // open a new stream 411 return conn.OpenStreamSync(ctx) 412 }