github.com/v2fly/v2ray-core/v4@v4.45.2/app/dns/nameserver_quic.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package dns
     5  
     6  import (
     7  	"context"
     8  	"net/url"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/lucas-clemente/quic-go"
    14  	"golang.org/x/net/dns/dnsmessage"
    15  	"golang.org/x/net/http2"
    16  
    17  	"github.com/v2fly/v2ray-core/v4/common"
    18  	"github.com/v2fly/v2ray-core/v4/common/buf"
    19  	"github.com/v2fly/v2ray-core/v4/common/net"
    20  	"github.com/v2fly/v2ray-core/v4/common/protocol/dns"
    21  	"github.com/v2fly/v2ray-core/v4/common/session"
    22  	"github.com/v2fly/v2ray-core/v4/common/signal/pubsub"
    23  	"github.com/v2fly/v2ray-core/v4/common/task"
    24  	dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
    25  	"github.com/v2fly/v2ray-core/v4/transport/internet/tls"
    26  )
    27  
    28  // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated
    29  // by selecting the ALPN token "dq" in the crypto handshake.
    30  const NextProtoDQ = "doq-i00"
    31  
    32  const handshakeIdleTimeout = time.Second * 8
    33  
    34  // QUICNameServer implemented DNS over QUIC
    35  type QUICNameServer struct {
    36  	sync.RWMutex
    37  	ips         map[string]record
    38  	pub         *pubsub.Service
    39  	cleanup     *task.Periodic
    40  	reqID       uint32
    41  	name        string
    42  	destination net.Destination
    43  	connection  quic.Connection
    44  }
    45  
    46  // NewQUICNameServer creates DNS-over-QUIC client object for local resolving
    47  func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
    48  	newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog()
    49  
    50  	var err error
    51  	port := net.Port(784)
    52  	if url.Port() != "" {
    53  		port, err = net.PortFromString(url.Port())
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  	}
    58  	dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
    59  
    60  	s := &QUICNameServer{
    61  		ips:         make(map[string]record),
    62  		pub:         pubsub.NewService(),
    63  		name:        url.String(),
    64  		destination: dest,
    65  	}
    66  	s.cleanup = &task.Periodic{
    67  		Interval: time.Minute,
    68  		Execute:  s.Cleanup,
    69  	}
    70  
    71  	return s, nil
    72  }
    73  
    74  // Name returns client name
    75  func (s *QUICNameServer) Name() string {
    76  	return s.name
    77  }
    78  
    79  // Cleanup clears expired items from cache
    80  func (s *QUICNameServer) Cleanup() error {
    81  	now := time.Now()
    82  	s.Lock()
    83  	defer s.Unlock()
    84  
    85  	if len(s.ips) == 0 {
    86  		return newError("nothing to do. stopping...")
    87  	}
    88  
    89  	for domain, record := range s.ips {
    90  		if record.A != nil && record.A.Expire.Before(now) {
    91  			record.A = nil
    92  		}
    93  		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
    94  			record.AAAA = nil
    95  		}
    96  
    97  		if record.A == nil && record.AAAA == nil {
    98  			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
    99  			delete(s.ips, domain)
   100  		} else {
   101  			s.ips[domain] = record
   102  		}
   103  	}
   104  
   105  	if len(s.ips) == 0 {
   106  		s.ips = make(map[string]record)
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
   113  	elapsed := time.Since(req.start)
   114  
   115  	s.Lock()
   116  	rec := s.ips[req.domain]
   117  	updated := false
   118  
   119  	switch req.reqType {
   120  	case dnsmessage.TypeA:
   121  		if isNewer(rec.A, ipRec) {
   122  			rec.A = ipRec
   123  			updated = true
   124  		}
   125  	case dnsmessage.TypeAAAA:
   126  		addr := make([]net.Address, 0)
   127  		for _, ip := range ipRec.IP {
   128  			if len(ip.IP()) == net.IPv6len {
   129  				addr = append(addr, ip)
   130  			}
   131  		}
   132  		ipRec.IP = addr
   133  		if isNewer(rec.AAAA, ipRec) {
   134  			rec.AAAA = ipRec
   135  			updated = true
   136  		}
   137  	}
   138  	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
   139  
   140  	if updated {
   141  		s.ips[req.domain] = rec
   142  	}
   143  	switch req.reqType {
   144  	case dnsmessage.TypeA:
   145  		s.pub.Publish(req.domain+"4", nil)
   146  	case dnsmessage.TypeAAAA:
   147  		s.pub.Publish(req.domain+"6", nil)
   148  	}
   149  	s.Unlock()
   150  	common.Must(s.cleanup.Start())
   151  }
   152  
   153  func (s *QUICNameServer) newReqID() uint16 {
   154  	return uint16(atomic.AddUint32(&s.reqID, 1))
   155  }
   156  
   157  func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
   158  	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
   159  
   160  	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
   161  
   162  	var deadline time.Time
   163  	if d, ok := ctx.Deadline(); ok {
   164  		deadline = d
   165  	} else {
   166  		deadline = time.Now().Add(time.Second * 5)
   167  	}
   168  
   169  	for _, req := range reqs {
   170  		go func(r *dnsRequest) {
   171  			// generate new context for each req, using same context
   172  			// may cause reqs all aborted if any one encounter an error
   173  			dnsCtx := ctx
   174  
   175  			// reserve internal dns server requested Inbound
   176  			if inbound := session.InboundFromContext(ctx); inbound != nil {
   177  				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   178  			}
   179  
   180  			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
   181  				Protocol:       "quic",
   182  				SkipDNSResolve: true,
   183  			})
   184  
   185  			var cancel context.CancelFunc
   186  			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
   187  			defer cancel()
   188  
   189  			b, err := dns.PackMessage(r.msg)
   190  			if err != nil {
   191  				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
   192  				return
   193  			}
   194  
   195  			conn, err := s.openStream(dnsCtx)
   196  			if err != nil {
   197  				newError("failed to open quic connection").Base(err).AtError().WriteToLog()
   198  				return
   199  			}
   200  
   201  			_, err = conn.Write(b.Bytes())
   202  			if err != nil {
   203  				newError("failed to send query").Base(err).AtError().WriteToLog()
   204  				return
   205  			}
   206  
   207  			_ = conn.Close()
   208  
   209  			respBuf := buf.New()
   210  			defer respBuf.Release()
   211  			n, err := respBuf.ReadFrom(conn)
   212  			if err != nil && n == 0 {
   213  				newError("failed to read response").Base(err).AtError().WriteToLog()
   214  				return
   215  			}
   216  
   217  			rec, err := parseResponse(respBuf.Bytes())
   218  			if err != nil {
   219  				newError("failed to handle response").Base(err).AtError().WriteToLog()
   220  				return
   221  			}
   222  			s.updateIP(r, rec)
   223  		}(req)
   224  	}
   225  }
   226  
   227  func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
   228  	s.RLock()
   229  	record, found := s.ips[domain]
   230  	s.RUnlock()
   231  
   232  	if !found {
   233  		return nil, errRecordNotFound
   234  	}
   235  
   236  	var err4 error
   237  	var err6 error
   238  	var ips []net.Address
   239  	var ip6 []net.Address
   240  
   241  	if option.IPv4Enable {
   242  		ips, err4 = record.A.getIPs()
   243  	}
   244  
   245  	if option.IPv6Enable {
   246  		ip6, err6 = record.AAAA.getIPs()
   247  		ips = append(ips, ip6...)
   248  	}
   249  
   250  	if len(ips) > 0 {
   251  		return toNetIP(ips)
   252  	}
   253  
   254  	if err4 != nil {
   255  		return nil, err4
   256  	}
   257  
   258  	if err6 != nil {
   259  		return nil, err6
   260  	}
   261  
   262  	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
   263  		return nil, dns_feature.ErrEmptyResponse
   264  	}
   265  
   266  	return nil, errRecordNotFound
   267  }
   268  
   269  // QueryIP is called from dns.Server->queryIPTimeout
   270  func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
   271  	fqdn := Fqdn(domain)
   272  
   273  	if disableCache {
   274  		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
   275  	} else {
   276  		ips, err := s.findIPsForDomain(fqdn, option)
   277  		if err != errRecordNotFound {
   278  			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
   279  			return ips, err
   280  		}
   281  	}
   282  
   283  	// ipv4 and ipv6 belong to different subscription groups
   284  	var sub4, sub6 *pubsub.Subscriber
   285  	if option.IPv4Enable {
   286  		sub4 = s.pub.Subscribe(fqdn + "4")
   287  		defer sub4.Close()
   288  	}
   289  	if option.IPv6Enable {
   290  		sub6 = s.pub.Subscribe(fqdn + "6")
   291  		defer sub6.Close()
   292  	}
   293  	done := make(chan interface{})
   294  	go func() {
   295  		if sub4 != nil {
   296  			select {
   297  			case <-sub4.Wait():
   298  			case <-ctx.Done():
   299  			}
   300  		}
   301  		if sub6 != nil {
   302  			select {
   303  			case <-sub6.Wait():
   304  			case <-ctx.Done():
   305  			}
   306  		}
   307  		close(done)
   308  	}()
   309  	s.sendQuery(ctx, fqdn, clientIP, option)
   310  
   311  	for {
   312  		ips, err := s.findIPsForDomain(fqdn, option)
   313  		if err != errRecordNotFound {
   314  			return ips, err
   315  		}
   316  
   317  		select {
   318  		case <-ctx.Done():
   319  			return nil, ctx.Err()
   320  		case <-done:
   321  		}
   322  	}
   323  }
   324  
   325  func isActive(s quic.Connection) bool {
   326  	select {
   327  	case <-s.Context().Done():
   328  		return false
   329  	default:
   330  		return true
   331  	}
   332  }
   333  
   334  func (s *QUICNameServer) getConnection(ctx context.Context) (quic.Connection, error) {
   335  	var conn quic.Connection
   336  	s.RLock()
   337  	conn = s.connection
   338  	if conn != nil && isActive(conn) {
   339  		s.RUnlock()
   340  		return conn, nil
   341  	}
   342  	if conn != nil {
   343  		// we're recreating the connection, let's create a new one
   344  		_ = conn.CloseWithError(0, "")
   345  	}
   346  	s.RUnlock()
   347  
   348  	s.Lock()
   349  	defer s.Unlock()
   350  
   351  	var err error
   352  	conn, err = s.openConnection(ctx)
   353  	if err != nil {
   354  		// This does not look too nice, but QUIC (or maybe quic-go)
   355  		// doesn't seem stable enough.
   356  		// Maybe retransmissions aren't fully implemented in quic-go?
   357  		// Anyways, the simple solution is to make a second try when
   358  		// it fails to open the QUIC connection.
   359  		conn, err = s.openConnection(ctx)
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  	}
   364  	s.connection = conn
   365  	return conn, nil
   366  }
   367  
   368  func (s *QUICNameServer) openConnection(ctx context.Context) (quic.Connection, error) {
   369  	tlsConfig := tls.Config{}
   370  	quicConfig := &quic.Config{
   371  		HandshakeIdleTimeout: handshakeIdleTimeout,
   372  	}
   373  
   374  	conn, err := quic.DialAddrContext(ctx, s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  
   379  	return conn, nil
   380  }
   381  
   382  func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {
   383  	conn, err := s.getConnection(ctx)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	// open a new stream
   389  	return conn.OpenStreamSync(ctx)
   390  }