github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/dns/dns.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/buf"
    11  	"github.com/xtls/xray-core/common/errors"
    12  	"github.com/xtls/xray-core/common/net"
    13  	dns_proto "github.com/xtls/xray-core/common/protocol/dns"
    14  	"github.com/xtls/xray-core/common/session"
    15  	"github.com/xtls/xray-core/common/signal"
    16  	"github.com/xtls/xray-core/common/task"
    17  	"github.com/xtls/xray-core/core"
    18  	"github.com/xtls/xray-core/features/dns"
    19  	"github.com/xtls/xray-core/features/policy"
    20  	"github.com/xtls/xray-core/transport"
    21  	"github.com/xtls/xray-core/transport/internet"
    22  	"github.com/xtls/xray-core/transport/internet/stat"
    23  	"golang.org/x/net/dns/dnsmessage"
    24  )
    25  
    26  func init() {
    27  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    28  		h := new(Handler)
    29  		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
    30  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
    31  				h.fdns = fdns
    32  			})
    33  			return h.Init(config.(*Config), dnsClient, policyManager)
    34  		}); err != nil {
    35  			return nil, err
    36  		}
    37  		return h, nil
    38  	}))
    39  }
    40  
    41  type ownLinkVerifier interface {
    42  	IsOwnLink(ctx context.Context) bool
    43  }
    44  
    45  type Handler struct {
    46  	client          dns.Client
    47  	fdns            dns.FakeDNSEngine
    48  	ownLinkVerifier ownLinkVerifier
    49  	server          net.Destination
    50  	timeout         time.Duration
    51  	nonIPQuery      string
    52  }
    53  
    54  func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
    55  	h.client = dnsClient
    56  	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
    57  
    58  	if v, ok := dnsClient.(ownLinkVerifier); ok {
    59  		h.ownLinkVerifier = v
    60  	}
    61  
    62  	if config.Server != nil {
    63  		h.server = config.Server.AsDestination()
    64  	}
    65  	h.nonIPQuery = config.Non_IPQuery
    66  	return nil
    67  }
    68  
    69  func (h *Handler) isOwnLink(ctx context.Context) bool {
    70  	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
    71  }
    72  
    73  func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
    74  	var parser dnsmessage.Parser
    75  	header, err := parser.Start(b)
    76  	if err != nil {
    77  		newError("parser start").Base(err).WriteToLog()
    78  		return
    79  	}
    80  
    81  	id = header.ID
    82  	q, err := parser.Question()
    83  	if err != nil {
    84  		newError("question").Base(err).WriteToLog()
    85  		return
    86  	}
    87  	qType = q.Type
    88  	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
    89  		return
    90  	}
    91  
    92  	domain = q.Name.String()
    93  	r = true
    94  	return
    95  }
    96  
    97  // Process implements proxy.Outbound.
    98  func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
    99  	outbounds := session.OutboundsFromContext(ctx)
   100  	ob := outbounds[len(outbounds) - 1]
   101  	if !ob.Target.IsValid() {
   102  		return newError("invalid outbound")
   103  	}
   104  	ob.Name = "dns"
   105  
   106  	srcNetwork := ob.Target.Network
   107  
   108  	dest := ob.Target
   109  	if h.server.Network != net.Network_Unknown {
   110  		dest.Network = h.server.Network
   111  	}
   112  	if h.server.Address != nil {
   113  		dest.Address = h.server.Address
   114  	}
   115  	if h.server.Port != 0 {
   116  		dest.Port = h.server.Port
   117  	}
   118  
   119  	newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))
   120  
   121  	conn := &outboundConn{
   122  		dialer: func() (stat.Connection, error) {
   123  			return d.Dial(ctx, dest)
   124  		},
   125  		connReady: make(chan struct{}, 1),
   126  	}
   127  
   128  	var reader dns_proto.MessageReader
   129  	var writer dns_proto.MessageWriter
   130  	if srcNetwork == net.Network_TCP {
   131  		reader = dns_proto.NewTCPReader(link.Reader)
   132  		writer = &dns_proto.TCPWriter{
   133  			Writer: link.Writer,
   134  		}
   135  	} else {
   136  		reader = &dns_proto.UDPReader{
   137  			Reader: link.Reader,
   138  		}
   139  		writer = &dns_proto.UDPWriter{
   140  			Writer: link.Writer,
   141  		}
   142  	}
   143  
   144  	var connReader dns_proto.MessageReader
   145  	var connWriter dns_proto.MessageWriter
   146  	if dest.Network == net.Network_TCP {
   147  		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
   148  		connWriter = &dns_proto.TCPWriter{
   149  			Writer: buf.NewWriter(conn),
   150  		}
   151  	} else {
   152  		connReader = &dns_proto.UDPReader{
   153  			Reader: buf.NewPacketReader(conn),
   154  		}
   155  		connWriter = &dns_proto.UDPWriter{
   156  			Writer: buf.NewWriter(conn),
   157  		}
   158  	}
   159  
   160  	if session.TimeoutOnlyFromContext(ctx) {
   161  		ctx, _ = context.WithCancel(context.Background())
   162  	}
   163  
   164  	ctx, cancel := context.WithCancel(ctx)
   165  	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
   166  
   167  	request := func() error {
   168  		defer conn.Close()
   169  
   170  		for {
   171  			b, err := reader.ReadMessage()
   172  			if err == io.EOF {
   173  				return nil
   174  			}
   175  
   176  			if err != nil {
   177  				return err
   178  			}
   179  
   180  			timer.Update()
   181  
   182  			if !h.isOwnLink(ctx) {
   183  				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
   184  				if isIPQuery {
   185  					go h.handleIPQuery(id, qType, domain, writer)
   186  				}
   187  				if isIPQuery || h.nonIPQuery == "drop" || qType == 65 {
   188  					b.Release()
   189  					continue
   190  				}
   191  			}
   192  
   193  			if err := connWriter.WriteMessage(b); err != nil {
   194  				return err
   195  			}
   196  		}
   197  	}
   198  
   199  	response := func() error {
   200  		for {
   201  			b, err := connReader.ReadMessage()
   202  			if err == io.EOF {
   203  				return nil
   204  			}
   205  
   206  			if err != nil {
   207  				return err
   208  			}
   209  
   210  			timer.Update()
   211  
   212  			if err := writer.WriteMessage(b); err != nil {
   213  				return err
   214  			}
   215  		}
   216  	}
   217  
   218  	if err := task.Run(ctx, request, response); err != nil {
   219  		return newError("connection ends").Base(err)
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
   226  	var ips []net.IP
   227  	var err error
   228  
   229  	var ttl uint32 = 600
   230  
   231  	switch qType {
   232  	case dnsmessage.TypeA:
   233  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   234  			IPv4Enable: true,
   235  			IPv6Enable: false,
   236  			FakeEnable: true,
   237  		})
   238  	case dnsmessage.TypeAAAA:
   239  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   240  			IPv4Enable: false,
   241  			IPv6Enable: true,
   242  			FakeEnable: true,
   243  		})
   244  	}
   245  
   246  	rcode := dns.RCodeFromError(err)
   247  	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
   248  		newError("ip query").Base(err).WriteToLog()
   249  		return
   250  	}
   251  
   252  	if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) {
   253  		ttl = 1
   254  	}
   255  
   256  	switch qType {
   257  	case dnsmessage.TypeA:
   258  		for i, ip := range ips {
   259  			ips[i] = ip.To4()
   260  		}
   261  	case dnsmessage.TypeAAAA:
   262  		for i, ip := range ips {
   263  			ips[i] = ip.To16()
   264  		}
   265  	}
   266  
   267  	b := buf.New()
   268  	rawBytes := b.Extend(buf.Size)
   269  	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
   270  		ID:                 id,
   271  		RCode:              dnsmessage.RCode(rcode),
   272  		RecursionAvailable: true,
   273  		RecursionDesired:   true,
   274  		Response:           true,
   275  		Authoritative:      true,
   276  	})
   277  	builder.EnableCompression()
   278  	common.Must(builder.StartQuestions())
   279  	common.Must(builder.Question(dnsmessage.Question{
   280  		Name:  dnsmessage.MustNewName(domain),
   281  		Class: dnsmessage.ClassINET,
   282  		Type:  qType,
   283  	}))
   284  	common.Must(builder.StartAnswers())
   285  
   286  	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}
   287  	for _, ip := range ips {
   288  		if len(ip) == net.IPv4len {
   289  			var r dnsmessage.AResource
   290  			copy(r.A[:], ip)
   291  			common.Must(builder.AResource(rHeader, r))
   292  		} else {
   293  			var r dnsmessage.AAAAResource
   294  			copy(r.AAAA[:], ip)
   295  			common.Must(builder.AAAAResource(rHeader, r))
   296  		}
   297  	}
   298  	msgBytes, err := builder.Finish()
   299  	if err != nil {
   300  		newError("pack message").Base(err).WriteToLog()
   301  		b.Release()
   302  		return
   303  	}
   304  	b.Resize(0, int32(len(msgBytes)))
   305  
   306  	if err := writer.WriteMessage(b); err != nil {
   307  		newError("write IP answer").Base(err).WriteToLog()
   308  	}
   309  }
   310  
   311  type outboundConn struct {
   312  	access sync.Mutex
   313  	dialer func() (stat.Connection, error)
   314  
   315  	conn      net.Conn
   316  	connReady chan struct{}
   317  }
   318  
   319  func (c *outboundConn) dial() error {
   320  	conn, err := c.dialer()
   321  	if err != nil {
   322  		return err
   323  	}
   324  	c.conn = conn
   325  	c.connReady <- struct{}{}
   326  	return nil
   327  }
   328  
   329  func (c *outboundConn) Write(b []byte) (int, error) {
   330  	c.access.Lock()
   331  
   332  	if c.conn == nil {
   333  		if err := c.dial(); err != nil {
   334  			c.access.Unlock()
   335  			newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog()
   336  			return len(b), nil
   337  		}
   338  	}
   339  
   340  	c.access.Unlock()
   341  
   342  	return c.conn.Write(b)
   343  }
   344  
   345  func (c *outboundConn) Read(b []byte) (int, error) {
   346  	var conn net.Conn
   347  	c.access.Lock()
   348  	conn = c.conn
   349  	c.access.Unlock()
   350  
   351  	if conn == nil {
   352  		_, open := <-c.connReady
   353  		if !open {
   354  			return 0, io.EOF
   355  		}
   356  		conn = c.conn
   357  	}
   358  
   359  	return conn.Read(b)
   360  }
   361  
   362  func (c *outboundConn) Close() error {
   363  	c.access.Lock()
   364  	close(c.connReady)
   365  	if c.conn != nil {
   366  		c.conn.Close()
   367  	}
   368  	c.access.Unlock()
   369  	return nil
   370  }