github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/nameserver_quic.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "net/url" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/quic-go/quic-go" 13 "github.com/xmplusdev/xmcore/common" 14 "github.com/xmplusdev/xmcore/common/buf" 15 "github.com/xmplusdev/xmcore/common/log" 16 "github.com/xmplusdev/xmcore/common/net" 17 "github.com/xmplusdev/xmcore/common/protocol/dns" 18 "github.com/xmplusdev/xmcore/common/session" 19 "github.com/xmplusdev/xmcore/common/signal/pubsub" 20 "github.com/xmplusdev/xmcore/common/task" 21 dns_feature "github.com/xmplusdev/xmcore/features/dns" 22 "github.com/xmplusdev/xmcore/transport/internet/tls" 23 "golang.org/x/net/dns/dnsmessage" 24 "golang.org/x/net/http2" 25 ) 26 27 // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated 28 // by selecting the ALPN token "dq" in the crypto handshake. 29 const NextProtoDQ = "doq" 30 31 const handshakeTimeout = time.Second * 8 32 33 // QUICNameServer implemented DNS over QUIC 34 type QUICNameServer struct { 35 sync.RWMutex 36 ips map[string]*record 37 pub *pubsub.Service 38 cleanup *task.Periodic 39 reqID uint32 40 name string 41 destination *net.Destination 42 connection quic.Connection 43 queryStrategy QueryStrategy 44 } 45 46 // NewQUICNameServer creates DNS-over-QUIC client object for local resolving 47 func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) { 48 newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog() 49 50 var err error 51 port := net.Port(853) 52 if url.Port() != "" { 53 port, err = net.PortFromString(url.Port()) 54 if err != nil { 55 return nil, err 56 } 57 } 58 dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) 59 60 s := &QUICNameServer{ 61 ips: make(map[string]*record), 62 pub: pubsub.NewService(), 63 name: url.String(), 64 destination: &dest, 65 queryStrategy: queryStrategy, 66 } 67 s.cleanup = &task.Periodic{ 68 Interval: time.Minute, 69 Execute: s.Cleanup, 70 } 71 72 return s, nil 73 } 74 75 // Name returns client name 76 func (s *QUICNameServer) Name() string { 77 return s.name 78 } 79 80 // Cleanup clears expired items from cache 81 func (s *QUICNameServer) Cleanup() error { 82 now := time.Now() 83 s.Lock() 84 defer s.Unlock() 85 86 if len(s.ips) == 0 { 87 return newError("nothing to do. stopping...") 88 } 89 90 for domain, record := range s.ips { 91 if record.A != nil && record.A.Expire.Before(now) { 92 record.A = nil 93 } 94 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 95 record.AAAA = nil 96 } 97 98 if record.A == nil && record.AAAA == nil { 99 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 100 delete(s.ips, domain) 101 } else { 102 s.ips[domain] = record 103 } 104 } 105 106 if len(s.ips) == 0 { 107 s.ips = make(map[string]*record) 108 } 109 110 return nil 111 } 112 113 func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 114 elapsed := time.Since(req.start) 115 116 s.Lock() 117 rec, found := s.ips[req.domain] 118 if !found { 119 rec = &record{} 120 } 121 updated := false 122 123 switch req.reqType { 124 case dnsmessage.TypeA: 125 if isNewer(rec.A, ipRec) { 126 rec.A = ipRec 127 updated = true 128 } 129 case dnsmessage.TypeAAAA: 130 addr := make([]net.Address, 0) 131 for _, ip := range ipRec.IP { 132 if len(ip.IP()) == net.IPv6len { 133 addr = append(addr, ip) 134 } 135 } 136 ipRec.IP = addr 137 if isNewer(rec.AAAA, ipRec) { 138 rec.AAAA = ipRec 139 updated = true 140 } 141 } 142 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 143 144 if updated { 145 s.ips[req.domain] = rec 146 } 147 switch req.reqType { 148 case dnsmessage.TypeA: 149 s.pub.Publish(req.domain+"4", nil) 150 case dnsmessage.TypeAAAA: 151 s.pub.Publish(req.domain+"6", nil) 152 } 153 s.Unlock() 154 common.Must(s.cleanup.Start()) 155 } 156 157 func (s *QUICNameServer) newReqID() uint16 { 158 return uint16(atomic.AddUint32(&s.reqID, 1)) 159 } 160 161 func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 162 newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx)) 163 164 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 165 166 var deadline time.Time 167 if d, ok := ctx.Deadline(); ok { 168 deadline = d 169 } else { 170 deadline = time.Now().Add(time.Second * 5) 171 } 172 173 for _, req := range reqs { 174 go func(r *dnsRequest) { 175 // generate new context for each req, using same context 176 // may cause reqs all aborted if any one encounter an error 177 dnsCtx := ctx 178 179 // reserve internal dns server requested Inbound 180 if inbound := session.InboundFromContext(ctx); inbound != nil { 181 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 182 } 183 184 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 185 Protocol: "quic", 186 SkipDNSResolve: true, 187 }) 188 189 var cancel context.CancelFunc 190 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 191 defer cancel() 192 193 b, err := dns.PackMessage(r.msg) 194 if err != nil { 195 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 196 return 197 } 198 199 dnsReqBuf := buf.New() 200 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 201 dnsReqBuf.Write(b.Bytes()) 202 b.Release() 203 204 conn, err := s.openStream(dnsCtx) 205 if err != nil { 206 newError("failed to open quic connection").Base(err).AtError().WriteToLog() 207 return 208 } 209 210 _, err = conn.Write(dnsReqBuf.Bytes()) 211 if err != nil { 212 newError("failed to send query").Base(err).AtError().WriteToLog() 213 return 214 } 215 216 _ = conn.Close() 217 218 respBuf := buf.New() 219 defer respBuf.Release() 220 n, err := respBuf.ReadFullFrom(conn, 2) 221 if err != nil && n == 0 { 222 newError("failed to read response length").Base(err).AtError().WriteToLog() 223 return 224 } 225 var length int16 226 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 227 if err != nil { 228 newError("failed to parse response length").Base(err).AtError().WriteToLog() 229 return 230 } 231 respBuf.Clear() 232 n, err = respBuf.ReadFullFrom(conn, int32(length)) 233 if err != nil && n == 0 { 234 newError("failed to read response length").Base(err).AtError().WriteToLog() 235 return 236 } 237 238 rec, err := parseResponse(respBuf.Bytes()) 239 if err != nil { 240 newError("failed to handle response").Base(err).AtError().WriteToLog() 241 return 242 } 243 s.updateIP(r, rec) 244 }(req) 245 } 246 } 247 248 func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 249 s.RLock() 250 record, found := s.ips[domain] 251 s.RUnlock() 252 253 if !found { 254 return nil, errRecordNotFound 255 } 256 257 var err4 error 258 var err6 error 259 var ips []net.Address 260 var ip6 []net.Address 261 262 if option.IPv4Enable { 263 ips, err4 = record.A.getIPs() 264 } 265 266 if option.IPv6Enable { 267 ip6, err6 = record.AAAA.getIPs() 268 ips = append(ips, ip6...) 269 } 270 271 if len(ips) > 0 { 272 return toNetIP(ips) 273 } 274 275 if err4 != nil { 276 return nil, err4 277 } 278 279 if err6 != nil { 280 return nil, err6 281 } 282 283 if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { 284 return nil, dns_feature.ErrEmptyResponse 285 } 286 287 return nil, errRecordNotFound 288 } 289 290 // QueryIP is called from dns.Server->queryIPTimeout 291 func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 292 fqdn := Fqdn(domain) 293 option = ResolveIpOptionOverride(s.queryStrategy, option) 294 if !option.IPv4Enable && !option.IPv6Enable { 295 return nil, dns_feature.ErrEmptyResponse 296 } 297 298 if disableCache { 299 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 300 } else { 301 ips, err := s.findIPsForDomain(fqdn, option) 302 if err != errRecordNotFound { 303 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 304 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 305 return ips, err 306 } 307 } 308 309 // ipv4 and ipv6 belong to different subscription groups 310 var sub4, sub6 *pubsub.Subscriber 311 if option.IPv4Enable { 312 sub4 = s.pub.Subscribe(fqdn + "4") 313 defer sub4.Close() 314 } 315 if option.IPv6Enable { 316 sub6 = s.pub.Subscribe(fqdn + "6") 317 defer sub6.Close() 318 } 319 done := make(chan interface{}) 320 go func() { 321 if sub4 != nil { 322 select { 323 case <-sub4.Wait(): 324 case <-ctx.Done(): 325 } 326 } 327 if sub6 != nil { 328 select { 329 case <-sub6.Wait(): 330 case <-ctx.Done(): 331 } 332 } 333 close(done) 334 }() 335 s.sendQuery(ctx, fqdn, clientIP, option) 336 start := time.Now() 337 338 for { 339 ips, err := s.findIPsForDomain(fqdn, option) 340 if err != errRecordNotFound { 341 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 342 return ips, err 343 } 344 345 select { 346 case <-ctx.Done(): 347 return nil, ctx.Err() 348 case <-done: 349 } 350 } 351 } 352 353 func isActive(s quic.Connection) bool { 354 select { 355 case <-s.Context().Done(): 356 return false 357 default: 358 return true 359 } 360 } 361 362 func (s *QUICNameServer) getConnection() (quic.Connection, error) { 363 var conn quic.Connection 364 s.RLock() 365 conn = s.connection 366 if conn != nil && isActive(conn) { 367 s.RUnlock() 368 return conn, nil 369 } 370 if conn != nil { 371 // we're recreating the connection, let's create a new one 372 _ = conn.CloseWithError(0, "") 373 } 374 s.RUnlock() 375 376 s.Lock() 377 defer s.Unlock() 378 379 var err error 380 conn, err = s.openConnection() 381 if err != nil { 382 // This does not look too nice, but QUIC (or maybe quic-go) 383 // doesn't seem stable enough. 384 // Maybe retransmissions aren't fully implemented in quic-go? 385 // Anyways, the simple solution is to make a second try when 386 // it fails to open the QUIC connection. 387 conn, err = s.openConnection() 388 if err != nil { 389 return nil, err 390 } 391 } 392 s.connection = conn 393 return conn, nil 394 } 395 396 func (s *QUICNameServer) openConnection() (quic.Connection, error) { 397 tlsConfig := tls.Config{} 398 quicConfig := &quic.Config{ 399 HandshakeIdleTimeout: handshakeTimeout, 400 } 401 tlsConfig.ServerName = s.destination.Address.String() 402 conn, err := quic.DialAddr(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig) 403 log.Record(&log.AccessMessage{ 404 From: "DNS", 405 To: s.destination, 406 Status: log.AccessAccepted, 407 Detour: "local", 408 }) 409 if err != nil { 410 return nil, err 411 } 412 413 return conn, nil 414 } 415 416 func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) { 417 conn, err := s.getConnection() 418 if err != nil { 419 return nil, err 420 } 421 422 // open a new stream 423 return conn.OpenStreamSync(ctx) 424 }