github.com/xraypb/xray-core@v1.6.6/app/dns/nameserver_quic.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"net/url"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/lucas-clemente/quic-go"
    11  	"github.com/xraypb/xray-core/common"
    12  	"github.com/xraypb/xray-core/common/buf"
    13  	"github.com/xraypb/xray-core/common/log"
    14  	"github.com/xraypb/xray-core/common/net"
    15  	"github.com/xraypb/xray-core/common/protocol/dns"
    16  	"github.com/xraypb/xray-core/common/session"
    17  	"github.com/xraypb/xray-core/common/signal/pubsub"
    18  	"github.com/xraypb/xray-core/common/task"
    19  	dns_feature "github.com/xraypb/xray-core/features/dns"
    20  	"github.com/xraypb/xray-core/transport/internet/tls"
    21  	"golang.org/x/net/dns/dnsmessage"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  // NextProtoDQ - During connection establishment, DNS/QUIC support is indicated
    26  // by selecting the ALPN token "dq" in the crypto handshake.
    27  const NextProtoDQ = "doq-i00"
    28  
    29  const handshakeTimeout = time.Second * 8
    30  
    31  // QUICNameServer implemented DNS over QUIC
    32  type QUICNameServer struct {
    33  	sync.RWMutex
    34  	ips         map[string]*record
    35  	pub         *pubsub.Service
    36  	cleanup     *task.Periodic
    37  	reqID       uint32
    38  	name        string
    39  	destination *net.Destination
    40  	connection  quic.Connection
    41  }
    42  
    43  // NewQUICNameServer creates DNS-over-QUIC client object for local resolving
    44  func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
    45  	newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog()
    46  
    47  	var err error
    48  	port := net.Port(784)
    49  	if url.Port() != "" {
    50  		port, err = net.PortFromString(url.Port())
    51  		if err != nil {
    52  			return nil, err
    53  		}
    54  	}
    55  	dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
    56  
    57  	s := &QUICNameServer{
    58  		ips:         make(map[string]*record),
    59  		pub:         pubsub.NewService(),
    60  		name:        url.String(),
    61  		destination: &dest,
    62  	}
    63  	s.cleanup = &task.Periodic{
    64  		Interval: time.Minute,
    65  		Execute:  s.Cleanup,
    66  	}
    67  
    68  	return s, nil
    69  }
    70  
    71  // Name returns client name
    72  func (s *QUICNameServer) Name() string {
    73  	return s.name
    74  }
    75  
    76  // Cleanup clears expired items from cache
    77  func (s *QUICNameServer) Cleanup() error {
    78  	now := time.Now()
    79  	s.Lock()
    80  	defer s.Unlock()
    81  
    82  	if len(s.ips) == 0 {
    83  		return newError("nothing to do. stopping...")
    84  	}
    85  
    86  	for domain, record := range s.ips {
    87  		if record.A != nil && record.A.Expire.Before(now) {
    88  			record.A = nil
    89  		}
    90  		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
    91  			record.AAAA = nil
    92  		}
    93  
    94  		if record.A == nil && record.AAAA == nil {
    95  			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
    96  			delete(s.ips, domain)
    97  		} else {
    98  			s.ips[domain] = record
    99  		}
   100  	}
   101  
   102  	if len(s.ips) == 0 {
   103  		s.ips = make(map[string]*record)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
   110  	elapsed := time.Since(req.start)
   111  
   112  	s.Lock()
   113  	rec, found := s.ips[req.domain]
   114  	if !found {
   115  		rec = &record{}
   116  	}
   117  	updated := false
   118  
   119  	switch req.reqType {
   120  	case dnsmessage.TypeA:
   121  		if isNewer(rec.A, ipRec) {
   122  			rec.A = ipRec
   123  			updated = true
   124  		}
   125  	case dnsmessage.TypeAAAA:
   126  		addr := make([]net.Address, 0)
   127  		for _, ip := range ipRec.IP {
   128  			if len(ip.IP()) == net.IPv6len {
   129  				addr = append(addr, ip)
   130  			}
   131  		}
   132  		ipRec.IP = addr
   133  		if isNewer(rec.AAAA, ipRec) {
   134  			rec.AAAA = ipRec
   135  			updated = true
   136  		}
   137  	}
   138  	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
   139  
   140  	if updated {
   141  		s.ips[req.domain] = rec
   142  	}
   143  	switch req.reqType {
   144  	case dnsmessage.TypeA:
   145  		s.pub.Publish(req.domain+"4", nil)
   146  	case dnsmessage.TypeAAAA:
   147  		s.pub.Publish(req.domain+"6", nil)
   148  	}
   149  	s.Unlock()
   150  	common.Must(s.cleanup.Start())
   151  }
   152  
   153  func (s *QUICNameServer) newReqID() uint16 {
   154  	return uint16(atomic.AddUint32(&s.reqID, 1))
   155  }
   156  
   157  func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
   158  	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
   159  
   160  	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
   161  
   162  	var deadline time.Time
   163  	if d, ok := ctx.Deadline(); ok {
   164  		deadline = d
   165  	} else {
   166  		deadline = time.Now().Add(time.Second * 5)
   167  	}
   168  
   169  	for _, req := range reqs {
   170  		go func(r *dnsRequest) {
   171  			// generate new context for each req, using same context
   172  			// may cause reqs all aborted if any one encounter an error
   173  			dnsCtx := ctx
   174  
   175  			// reserve internal dns server requested Inbound
   176  			if inbound := session.InboundFromContext(ctx); inbound != nil {
   177  				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   178  			}
   179  
   180  			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
   181  				Protocol:       "quic",
   182  				SkipDNSResolve: true,
   183  			})
   184  
   185  			var cancel context.CancelFunc
   186  			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
   187  			defer cancel()
   188  
   189  			b, err := dns.PackMessage(r.msg)
   190  			if err != nil {
   191  				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
   192  				return
   193  			}
   194  
   195  			conn, err := s.openStream(dnsCtx)
   196  			if err != nil {
   197  				newError("failed to open quic connection").Base(err).AtError().WriteToLog()
   198  				return
   199  			}
   200  
   201  			_, err = conn.Write(b.Bytes())
   202  			if err != nil {
   203  				newError("failed to send query").Base(err).AtError().WriteToLog()
   204  				return
   205  			}
   206  
   207  			_ = conn.Close()
   208  
   209  			respBuf := buf.New()
   210  			defer respBuf.Release()
   211  			n, err := respBuf.ReadFrom(conn)
   212  			if err != nil && n == 0 {
   213  				newError("failed to read response").Base(err).AtError().WriteToLog()
   214  				return
   215  			}
   216  
   217  			rec, err := parseResponse(respBuf.Bytes())
   218  			if err != nil {
   219  				newError("failed to handle response").Base(err).AtError().WriteToLog()
   220  				return
   221  			}
   222  			s.updateIP(r, rec)
   223  		}(req)
   224  	}
   225  }
   226  
   227  func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
   228  	s.RLock()
   229  	record, found := s.ips[domain]
   230  	s.RUnlock()
   231  
   232  	if !found {
   233  		return nil, errRecordNotFound
   234  	}
   235  
   236  	var err4 error
   237  	var err6 error
   238  	var ips []net.Address
   239  	var ip6 []net.Address
   240  
   241  	if option.IPv4Enable {
   242  		ips, err4 = record.A.getIPs()
   243  	}
   244  
   245  	if option.IPv6Enable {
   246  		ip6, err6 = record.AAAA.getIPs()
   247  		ips = append(ips, ip6...)
   248  	}
   249  
   250  	if len(ips) > 0 {
   251  		return toNetIP(ips)
   252  	}
   253  
   254  	if err4 != nil {
   255  		return nil, err4
   256  	}
   257  
   258  	if err6 != nil {
   259  		return nil, err6
   260  	}
   261  
   262  	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
   263  		return nil, dns_feature.ErrEmptyResponse
   264  	}
   265  
   266  	return nil, errRecordNotFound
   267  }
   268  
   269  // QueryIP is called from dns.Server->queryIPTimeout
   270  func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
   271  	fqdn := Fqdn(domain)
   272  
   273  	if disableCache {
   274  		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
   275  	} else {
   276  		ips, err := s.findIPsForDomain(fqdn, option)
   277  		if err != errRecordNotFound {
   278  			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
   279  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
   280  			return ips, err
   281  		}
   282  	}
   283  
   284  	// ipv4 and ipv6 belong to different subscription groups
   285  	var sub4, sub6 *pubsub.Subscriber
   286  	if option.IPv4Enable {
   287  		sub4 = s.pub.Subscribe(fqdn + "4")
   288  		defer sub4.Close()
   289  	}
   290  	if option.IPv6Enable {
   291  		sub6 = s.pub.Subscribe(fqdn + "6")
   292  		defer sub6.Close()
   293  	}
   294  	done := make(chan interface{})
   295  	go func() {
   296  		if sub4 != nil {
   297  			select {
   298  			case <-sub4.Wait():
   299  			case <-ctx.Done():
   300  			}
   301  		}
   302  		if sub6 != nil {
   303  			select {
   304  			case <-sub6.Wait():
   305  			case <-ctx.Done():
   306  			}
   307  		}
   308  		close(done)
   309  	}()
   310  	s.sendQuery(ctx, fqdn, clientIP, option)
   311  	start := time.Now()
   312  
   313  	for {
   314  		ips, err := s.findIPsForDomain(fqdn, option)
   315  		if err != errRecordNotFound {
   316  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
   317  			return ips, err
   318  		}
   319  
   320  		select {
   321  		case <-ctx.Done():
   322  			return nil, ctx.Err()
   323  		case <-done:
   324  		}
   325  	}
   326  }
   327  
   328  func isActive(s quic.Connection) bool {
   329  	select {
   330  	case <-s.Context().Done():
   331  		return false
   332  	default:
   333  		return true
   334  	}
   335  }
   336  
   337  func (s *QUICNameServer) getConnection() (quic.Connection, error) {
   338  	var conn quic.Connection
   339  	s.RLock()
   340  	conn = s.connection
   341  	if conn != nil && isActive(conn) {
   342  		s.RUnlock()
   343  		return conn, nil
   344  	}
   345  	if conn != nil {
   346  		// we're recreating the connection, let's create a new one
   347  		_ = conn.CloseWithError(0, "")
   348  	}
   349  	s.RUnlock()
   350  
   351  	s.Lock()
   352  	defer s.Unlock()
   353  
   354  	var err error
   355  	conn, err = s.openConnection()
   356  	if err != nil {
   357  		// This does not look too nice, but QUIC (or maybe quic-go)
   358  		// doesn't seem stable enough.
   359  		// Maybe retransmissions aren't fully implemented in quic-go?
   360  		// Anyways, the simple solution is to make a second try when
   361  		// it fails to open the QUIC connection.
   362  		conn, err = s.openConnection()
   363  		if err != nil {
   364  			return nil, err
   365  		}
   366  	}
   367  	s.connection = conn
   368  	return conn, nil
   369  }
   370  
   371  func (s *QUICNameServer) openConnection() (quic.Connection, error) {
   372  	tlsConfig := tls.Config{}
   373  	quicConfig := &quic.Config{
   374  		HandshakeIdleTimeout: handshakeTimeout,
   375  	}
   376  
   377  	conn, err := quic.DialAddrContext(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)
   378  	log.Record(&log.AccessMessage{
   379  		From:   "DNS",
   380  		To:     s.destination,
   381  		Status: log.AccessAccepted,
   382  		Detour: "local",
   383  	})
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	return conn, nil
   389  }
   390  
   391  func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {
   392  	conn, err := s.getConnection()
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  
   397  	// open a new stream
   398  	return conn.OpenStreamSync(ctx)
   399  }