github.com/moqsien/xraycore@v1.8.5/proxy/dns/dns.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/moqsien/xraycore/common"
    10  	"github.com/moqsien/xraycore/common/buf"
    11  	"github.com/moqsien/xraycore/common/errors"
    12  	"github.com/moqsien/xraycore/common/net"
    13  	dns_proto "github.com/moqsien/xraycore/common/protocol/dns"
    14  	"github.com/moqsien/xraycore/common/session"
    15  	"github.com/moqsien/xraycore/common/signal"
    16  	"github.com/moqsien/xraycore/common/task"
    17  	"github.com/moqsien/xraycore/core"
    18  	"github.com/moqsien/xraycore/features/dns"
    19  	"github.com/moqsien/xraycore/features/policy"
    20  	"github.com/moqsien/xraycore/transport"
    21  	"github.com/moqsien/xraycore/transport/internet"
    22  	"github.com/moqsien/xraycore/transport/internet/stat"
    23  	"golang.org/x/net/dns/dnsmessage"
    24  )
    25  
    26  func init() {
    27  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    28  		h := new(Handler)
    29  		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
    30  			return h.Init(config.(*Config), dnsClient, policyManager)
    31  		}); err != nil {
    32  			return nil, err
    33  		}
    34  		return h, nil
    35  	}))
    36  }
    37  
    38  type ownLinkVerifier interface {
    39  	IsOwnLink(ctx context.Context) bool
    40  }
    41  
    42  type Handler struct {
    43  	client          dns.Client
    44  	ownLinkVerifier ownLinkVerifier
    45  	server          net.Destination
    46  	timeout         time.Duration
    47  	nonIPQuery      string
    48  }
    49  
    50  func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
    51  	h.client = dnsClient
    52  	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
    53  
    54  	if v, ok := dnsClient.(ownLinkVerifier); ok {
    55  		h.ownLinkVerifier = v
    56  	}
    57  
    58  	if config.Server != nil {
    59  		h.server = config.Server.AsDestination()
    60  	}
    61  	h.nonIPQuery = config.Non_IPQuery
    62  	return nil
    63  }
    64  
    65  func (h *Handler) isOwnLink(ctx context.Context) bool {
    66  	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
    67  }
    68  
    69  func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
    70  	var parser dnsmessage.Parser
    71  	header, err := parser.Start(b)
    72  	if err != nil {
    73  		newError("parser start").Base(err).WriteToLog()
    74  		return
    75  	}
    76  
    77  	id = header.ID
    78  	q, err := parser.Question()
    79  	if err != nil {
    80  		newError("question").Base(err).WriteToLog()
    81  		return
    82  	}
    83  	qType = q.Type
    84  	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
    85  		return
    86  	}
    87  
    88  	domain = q.Name.String()
    89  	r = true
    90  	return
    91  }
    92  
    93  // Process implements proxy.Outbound.
    94  func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
    95  	outbound := session.OutboundFromContext(ctx)
    96  	if outbound == nil || !outbound.Target.IsValid() {
    97  		return newError("invalid outbound")
    98  	}
    99  
   100  	srcNetwork := outbound.Target.Network
   101  
   102  	dest := outbound.Target
   103  	if h.server.Network != net.Network_Unknown {
   104  		dest.Network = h.server.Network
   105  	}
   106  	if h.server.Address != nil {
   107  		dest.Address = h.server.Address
   108  	}
   109  	if h.server.Port != 0 {
   110  		dest.Port = h.server.Port
   111  	}
   112  
   113  	newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))
   114  
   115  	conn := &outboundConn{
   116  		dialer: func() (stat.Connection, error) {
   117  			return d.Dial(ctx, dest)
   118  		},
   119  		connReady: make(chan struct{}, 1),
   120  	}
   121  
   122  	var reader dns_proto.MessageReader
   123  	var writer dns_proto.MessageWriter
   124  	if srcNetwork == net.Network_TCP {
   125  		reader = dns_proto.NewTCPReader(link.Reader)
   126  		writer = &dns_proto.TCPWriter{
   127  			Writer: link.Writer,
   128  		}
   129  	} else {
   130  		reader = &dns_proto.UDPReader{
   131  			Reader: link.Reader,
   132  		}
   133  		writer = &dns_proto.UDPWriter{
   134  			Writer: link.Writer,
   135  		}
   136  	}
   137  
   138  	var connReader dns_proto.MessageReader
   139  	var connWriter dns_proto.MessageWriter
   140  	if dest.Network == net.Network_TCP {
   141  		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
   142  		connWriter = &dns_proto.TCPWriter{
   143  			Writer: buf.NewWriter(conn),
   144  		}
   145  	} else {
   146  		connReader = &dns_proto.UDPReader{
   147  			Reader: buf.NewPacketReader(conn),
   148  		}
   149  		connWriter = &dns_proto.UDPWriter{
   150  			Writer: buf.NewWriter(conn),
   151  		}
   152  	}
   153  
   154  	if session.TimeoutOnlyFromContext(ctx) {
   155  		ctx, _ = context.WithCancel(context.Background())
   156  	}
   157  
   158  	ctx, cancel := context.WithCancel(ctx)
   159  	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
   160  
   161  	request := func() error {
   162  		defer conn.Close()
   163  
   164  		for {
   165  			b, err := reader.ReadMessage()
   166  			if err == io.EOF {
   167  				return nil
   168  			}
   169  
   170  			if err != nil {
   171  				return err
   172  			}
   173  
   174  			timer.Update()
   175  
   176  			if !h.isOwnLink(ctx) {
   177  				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
   178  				if isIPQuery {
   179  					go h.handleIPQuery(id, qType, domain, writer)
   180  				}
   181  				if isIPQuery || h.nonIPQuery == "drop" {
   182  					b.Release()
   183  					continue
   184  				}
   185  			}
   186  
   187  			if err := connWriter.WriteMessage(b); err != nil {
   188  				return err
   189  			}
   190  		}
   191  	}
   192  
   193  	response := func() error {
   194  		for {
   195  			b, err := connReader.ReadMessage()
   196  			if err == io.EOF {
   197  				return nil
   198  			}
   199  
   200  			if err != nil {
   201  				return err
   202  			}
   203  
   204  			timer.Update()
   205  
   206  			if err := writer.WriteMessage(b); err != nil {
   207  				return err
   208  			}
   209  		}
   210  	}
   211  
   212  	if err := task.Run(ctx, request, response); err != nil {
   213  		return newError("connection ends").Base(err)
   214  	}
   215  
   216  	return nil
   217  }
   218  
   219  func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
   220  	var ips []net.IP
   221  	var err error
   222  
   223  	var ttl uint32 = 600
   224  
   225  	switch qType {
   226  	case dnsmessage.TypeA:
   227  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   228  			IPv4Enable: true,
   229  			IPv6Enable: false,
   230  			FakeEnable: true,
   231  		})
   232  	case dnsmessage.TypeAAAA:
   233  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   234  			IPv4Enable: false,
   235  			IPv6Enable: true,
   236  			FakeEnable: true,
   237  		})
   238  	}
   239  
   240  	rcode := dns.RCodeFromError(err)
   241  	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
   242  		newError("ip query").Base(err).WriteToLog()
   243  		return
   244  	}
   245  
   246  	switch qType {
   247  	case dnsmessage.TypeA:
   248  		for i, ip := range ips {
   249  			ips[i] = ip.To4()
   250  		}
   251  	case dnsmessage.TypeAAAA:
   252  		for i, ip := range ips {
   253  			ips[i] = ip.To16()
   254  		}
   255  	}
   256  
   257  	b := buf.New()
   258  	rawBytes := b.Extend(buf.Size)
   259  	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
   260  		ID:                 id,
   261  		RCode:              dnsmessage.RCode(rcode),
   262  		RecursionAvailable: true,
   263  		RecursionDesired:   true,
   264  		Response:           true,
   265  		Authoritative:      true,
   266  	})
   267  	builder.EnableCompression()
   268  	common.Must(builder.StartQuestions())
   269  	common.Must(builder.Question(dnsmessage.Question{
   270  		Name:  dnsmessage.MustNewName(domain),
   271  		Class: dnsmessage.ClassINET,
   272  		Type:  qType,
   273  	}))
   274  	common.Must(builder.StartAnswers())
   275  
   276  	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}
   277  	for _, ip := range ips {
   278  		if len(ip) == net.IPv4len {
   279  			var r dnsmessage.AResource
   280  			copy(r.A[:], ip)
   281  			common.Must(builder.AResource(rHeader, r))
   282  		} else {
   283  			var r dnsmessage.AAAAResource
   284  			copy(r.AAAA[:], ip)
   285  			common.Must(builder.AAAAResource(rHeader, r))
   286  		}
   287  	}
   288  	msgBytes, err := builder.Finish()
   289  	if err != nil {
   290  		newError("pack message").Base(err).WriteToLog()
   291  		b.Release()
   292  		return
   293  	}
   294  	b.Resize(0, int32(len(msgBytes)))
   295  
   296  	if err := writer.WriteMessage(b); err != nil {
   297  		newError("write IP answer").Base(err).WriteToLog()
   298  	}
   299  }
   300  
   301  type outboundConn struct {
   302  	access sync.Mutex
   303  	dialer func() (stat.Connection, error)
   304  
   305  	conn      net.Conn
   306  	connReady chan struct{}
   307  }
   308  
   309  func (c *outboundConn) dial() error {
   310  	conn, err := c.dialer()
   311  	if err != nil {
   312  		return err
   313  	}
   314  	c.conn = conn
   315  	c.connReady <- struct{}{}
   316  	return nil
   317  }
   318  
   319  func (c *outboundConn) Write(b []byte) (int, error) {
   320  	c.access.Lock()
   321  
   322  	if c.conn == nil {
   323  		if err := c.dial(); err != nil {
   324  			c.access.Unlock()
   325  			newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog()
   326  			return len(b), nil
   327  		}
   328  	}
   329  
   330  	c.access.Unlock()
   331  
   332  	return c.conn.Write(b)
   333  }
   334  
   335  func (c *outboundConn) Read(b []byte) (int, error) {
   336  	var conn net.Conn
   337  	c.access.Lock()
   338  	conn = c.conn
   339  	c.access.Unlock()
   340  
   341  	if conn == nil {
   342  		_, open := <-c.connReady
   343  		if !open {
   344  			return 0, io.EOF
   345  		}
   346  		conn = c.conn
   347  	}
   348  
   349  	return conn.Read(b)
   350  }
   351  
   352  func (c *outboundConn) Close() error {
   353  	c.access.Lock()
   354  	close(c.connReady)
   355  	if c.conn != nil {
   356  		c.conn.Close()
   357  	}
   358  	c.access.Unlock()
   359  	return nil
   360  }