github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/dns/nameserver_quic.go (about)

     1  package dns
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"net/url"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/quic-go/quic-go"
    13  	"golang.org/x/net/dns/dnsmessage"
    14  
    15  	"github.com/v2fly/v2ray-core/v5/common"
    16  	"github.com/v2fly/v2ray-core/v5/common/buf"
    17  	"github.com/v2fly/v2ray-core/v5/common/net"
    18  	"github.com/v2fly/v2ray-core/v5/common/protocol/dns"
    19  	"github.com/v2fly/v2ray-core/v5/common/session"
    20  	"github.com/v2fly/v2ray-core/v5/common/signal/pubsub"
    21  	"github.com/v2fly/v2ray-core/v5/common/task"
    22  	dns_feature "github.com/v2fly/v2ray-core/v5/features/dns"
    23  	"github.com/v2fly/v2ray-core/v5/transport/internet/tls"
    24  )
    25  
    26  // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated
    27  // by selecting the ALPN token "doq" in the crypto handshake.
    28  const NextProtoDQ = "doq"
    29  
    30  const handshakeIdleTimeout = time.Second * 8
    31  
    32  // QUICNameServer implemented DNS over QUIC
    33  type QUICNameServer struct {
    34  	sync.RWMutex
    35  	ips         map[string]record
    36  	pub         *pubsub.Service
    37  	cleanup     *task.Periodic
    38  	reqID       uint32
    39  	name        string
    40  	destination net.Destination
    41  	connection  quic.Connection
    42  }
    43  
    44  // NewQUICNameServer creates DNS-over-QUIC client object for local resolving
    45  func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
    46  	newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog()
    47  
    48  	var err error
    49  	port := net.Port(853)
    50  	if url.Port() != "" {
    51  		port, err = net.PortFromString(url.Port())
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  	}
    56  	dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
    57  
    58  	s := &QUICNameServer{
    59  		ips:         make(map[string]record),
    60  		pub:         pubsub.NewService(),
    61  		name:        url.String(),
    62  		destination: dest,
    63  	}
    64  	s.cleanup = &task.Periodic{
    65  		Interval: time.Minute,
    66  		Execute:  s.Cleanup,
    67  	}
    68  
    69  	return s, nil
    70  }
    71  
    72  // Name returns client name
    73  func (s *QUICNameServer) Name() string {
    74  	return s.name
    75  }
    76  
    77  // Cleanup clears expired items from cache
    78  func (s *QUICNameServer) Cleanup() error {
    79  	now := time.Now()
    80  	s.Lock()
    81  	defer s.Unlock()
    82  
    83  	if len(s.ips) == 0 {
    84  		return newError("nothing to do. stopping...")
    85  	}
    86  
    87  	for domain, record := range s.ips {
    88  		if record.A != nil && record.A.Expire.Before(now) {
    89  			record.A = nil
    90  		}
    91  		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
    92  			record.AAAA = nil
    93  		}
    94  
    95  		if record.A == nil && record.AAAA == nil {
    96  			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
    97  			delete(s.ips, domain)
    98  		} else {
    99  			s.ips[domain] = record
   100  		}
   101  	}
   102  
   103  	if len(s.ips) == 0 {
   104  		s.ips = make(map[string]record)
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
   111  	elapsed := time.Since(req.start)
   112  
   113  	s.Lock()
   114  	rec := s.ips[req.domain]
   115  	updated := false
   116  
   117  	switch req.reqType {
   118  	case dnsmessage.TypeA:
   119  		if isNewer(rec.A, ipRec) {
   120  			rec.A = ipRec
   121  			updated = true
   122  		}
   123  	case dnsmessage.TypeAAAA:
   124  		addr := make([]net.Address, 0)
   125  		for _, ip := range ipRec.IP {
   126  			if len(ip.IP()) == net.IPv6len {
   127  				addr = append(addr, ip)
   128  			}
   129  		}
   130  		ipRec.IP = addr
   131  		if isNewer(rec.AAAA, ipRec) {
   132  			rec.AAAA = ipRec
   133  			updated = true
   134  		}
   135  	}
   136  	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
   137  
   138  	if updated {
   139  		s.ips[req.domain] = rec
   140  	}
   141  	switch req.reqType {
   142  	case dnsmessage.TypeA:
   143  		s.pub.Publish(req.domain+"4", nil)
   144  	case dnsmessage.TypeAAAA:
   145  		s.pub.Publish(req.domain+"6", nil)
   146  	}
   147  	s.Unlock()
   148  	common.Must(s.cleanup.Start())
   149  }
   150  
   151  func (s *QUICNameServer) newReqID() uint16 {
   152  	return uint16(atomic.AddUint32(&s.reqID, 1))
   153  }
   154  
   155  func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
   156  	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
   157  
   158  	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
   159  
   160  	var deadline time.Time
   161  	if d, ok := ctx.Deadline(); ok {
   162  		deadline = d
   163  	} else {
   164  		deadline = time.Now().Add(time.Second * 5)
   165  	}
   166  
   167  	for _, req := range reqs {
   168  		go func(r *dnsRequest) {
   169  			// generate new context for each req, using same context
   170  			// may cause reqs all aborted if any one encounter an error
   171  			dnsCtx := ctx
   172  
   173  			// reserve internal dns server requested Inbound
   174  			if inbound := session.InboundFromContext(ctx); inbound != nil {
   175  				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   176  			}
   177  
   178  			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
   179  				Protocol:       "quic",
   180  				SkipDNSResolve: true,
   181  			})
   182  
   183  			var cancel context.CancelFunc
   184  			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
   185  			defer cancel()
   186  
   187  			b, err := dns.PackMessage(r.msg)
   188  			if err != nil {
   189  				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
   190  				return
   191  			}
   192  
   193  			dnsReqBuf := buf.New()
   194  			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
   195  			dnsReqBuf.Write(b.Bytes())
   196  			b.Release()
   197  
   198  			conn, err := s.openStream(dnsCtx)
   199  			if err != nil {
   200  				newError("failed to open quic connection").Base(err).AtError().WriteToLog()
   201  				return
   202  			}
   203  
   204  			_, err = conn.Write(dnsReqBuf.Bytes())
   205  			if err != nil {
   206  				newError("failed to send query").Base(err).AtError().WriteToLog()
   207  				return
   208  			}
   209  
   210  			_ = conn.Close()
   211  
   212  			respBuf := buf.New()
   213  			defer respBuf.Release()
   214  			n, err := respBuf.ReadFullFrom(conn, 2)
   215  			if err != nil && n == 0 {
   216  				newError("failed to read response length").Base(err).AtError().WriteToLog()
   217  				return
   218  			}
   219  			var length int16
   220  			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
   221  			if err != nil {
   222  				newError("failed to parse response length").Base(err).AtError().WriteToLog()
   223  				return
   224  			}
   225  			respBuf.Clear()
   226  			n, err = respBuf.ReadFullFrom(conn, int32(length))
   227  			if err != nil && n == 0 {
   228  				newError("failed to read response length").Base(err).AtError().WriteToLog()
   229  				return
   230  			}
   231  
   232  			rec, err := parseResponse(respBuf.Bytes())
   233  			if err != nil {
   234  				newError("failed to handle response").Base(err).AtError().WriteToLog()
   235  				return
   236  			}
   237  			s.updateIP(r, rec)
   238  		}(req)
   239  	}
   240  }
   241  
   242  func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
   243  	s.RLock()
   244  	record, found := s.ips[domain]
   245  	s.RUnlock()
   246  
   247  	if !found {
   248  		return nil, errRecordNotFound
   249  	}
   250  
   251  	var ips []net.Address
   252  	var lastErr error
   253  	if option.IPv4Enable {
   254  		a, err := record.A.getIPs()
   255  		if err != nil {
   256  			lastErr = err
   257  		}
   258  		ips = append(ips, a...)
   259  	}
   260  
   261  	if option.IPv6Enable {
   262  		aaaa, err := record.AAAA.getIPs()
   263  		if err != nil {
   264  			lastErr = err
   265  		}
   266  		ips = append(ips, aaaa...)
   267  	}
   268  
   269  	if len(ips) > 0 {
   270  		return toNetIP(ips)
   271  	}
   272  
   273  	if lastErr != nil {
   274  		return nil, lastErr
   275  	}
   276  
   277  	return nil, dns_feature.ErrEmptyResponse
   278  }
   279  
   280  // QueryIP is called from dns.Server->queryIPTimeout
   281  func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
   282  	fqdn := Fqdn(domain)
   283  
   284  	if disableCache {
   285  		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
   286  	} else {
   287  		ips, err := s.findIPsForDomain(fqdn, option)
   288  		if err != errRecordNotFound {
   289  			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
   290  			return ips, err
   291  		}
   292  	}
   293  
   294  	// ipv4 and ipv6 belong to different subscription groups
   295  	var sub4, sub6 *pubsub.Subscriber
   296  	if option.IPv4Enable {
   297  		sub4 = s.pub.Subscribe(fqdn + "4")
   298  		defer sub4.Close()
   299  	}
   300  	if option.IPv6Enable {
   301  		sub6 = s.pub.Subscribe(fqdn + "6")
   302  		defer sub6.Close()
   303  	}
   304  	done := make(chan interface{})
   305  	go func() {
   306  		if sub4 != nil {
   307  			select {
   308  			case <-sub4.Wait():
   309  			case <-ctx.Done():
   310  			}
   311  		}
   312  		if sub6 != nil {
   313  			select {
   314  			case <-sub6.Wait():
   315  			case <-ctx.Done():
   316  			}
   317  		}
   318  		close(done)
   319  	}()
   320  	s.sendQuery(ctx, fqdn, clientIP, option)
   321  
   322  	for {
   323  		ips, err := s.findIPsForDomain(fqdn, option)
   324  		if err != errRecordNotFound {
   325  			return ips, err
   326  		}
   327  
   328  		select {
   329  		case <-ctx.Done():
   330  			return nil, ctx.Err()
   331  		case <-done:
   332  		}
   333  	}
   334  }
   335  
   336  func isActive(s quic.Connection) bool {
   337  	select {
   338  	case <-s.Context().Done():
   339  		return false
   340  	default:
   341  		return true
   342  	}
   343  }
   344  
   345  func (s *QUICNameServer) getConnection(ctx context.Context) (quic.Connection, error) {
   346  	var conn quic.Connection
   347  	s.RLock()
   348  	conn = s.connection
   349  	if conn != nil && isActive(conn) {
   350  		s.RUnlock()
   351  		return conn, nil
   352  	}
   353  	if conn != nil {
   354  		// we're recreating the connection, let's create a new one
   355  		_ = conn.CloseWithError(0, "")
   356  	}
   357  	s.RUnlock()
   358  
   359  	s.Lock()
   360  	defer s.Unlock()
   361  
   362  	var err error
   363  	conn, err = s.openConnection(ctx)
   364  	if err != nil {
   365  		// This does not look too nice, but QUIC (or maybe quic-go)
   366  		// doesn't seem stable enough.
   367  		// Maybe retransmissions aren't fully implemented in quic-go?
   368  		// Anyways, the simple solution is to make a second try when
   369  		// it fails to open the QUIC connection.
   370  		conn, err = s.openConnection(ctx)
   371  		if err != nil {
   372  			return nil, err
   373  		}
   374  	}
   375  	s.connection = conn
   376  	return conn, nil
   377  }
   378  
   379  func (s *QUICNameServer) openConnection(ctx context.Context) (quic.Connection, error) {
   380  	tlsConfig := tls.Config{
   381  		ServerName: func() string {
   382  			switch s.destination.Address.Family() {
   383  			case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
   384  				return s.destination.Address.IP().String()
   385  			case net.AddressFamilyDomain:
   386  				return s.destination.Address.Domain()
   387  			default:
   388  				panic("unknown address family")
   389  			}
   390  		}(),
   391  	}
   392  	quicConfig := &quic.Config{
   393  		HandshakeIdleTimeout: handshakeIdleTimeout,
   394  	}
   395  
   396  	conn, err := quic.DialAddr(ctx, s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto(NextProtoDQ)), quicConfig)
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  
   401  	return conn, nil
   402  }
   403  
   404  func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {
   405  	conn, err := s.getConnection(ctx)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  
   410  	// open a new stream
   411  	return conn.OpenStreamSync(ctx)
   412  }