github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/dns/dns.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xmplusdev/xmcore/common"
    10  	"github.com/xmplusdev/xmcore/common/buf"
    11  	"github.com/xmplusdev/xmcore/common/errors"
    12  	"github.com/xmplusdev/xmcore/common/net"
    13  	dns_proto "github.com/xmplusdev/xmcore/common/protocol/dns"
    14  	"github.com/xmplusdev/xmcore/common/session"
    15  	"github.com/xmplusdev/xmcore/common/signal"
    16  	"github.com/xmplusdev/xmcore/common/task"
    17  	"github.com/xmplusdev/xmcore/core"
    18  	"github.com/xmplusdev/xmcore/features/dns"
    19  	"github.com/xmplusdev/xmcore/features/policy"
    20  	"github.com/xmplusdev/xmcore/transport"
    21  	"github.com/xmplusdev/xmcore/transport/internet"
    22  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    23  	"golang.org/x/net/dns/dnsmessage"
    24  )
    25  
    26  func init() {
    27  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    28  		h := new(Handler)
    29  		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
    30  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
    31  				h.fdns = fdns
    32  			})
    33  			return h.Init(config.(*Config), dnsClient, policyManager)
    34  		}); err != nil {
    35  			return nil, err
    36  		}
    37  		return h, nil
    38  	}))
    39  }
    40  
// ownLinkVerifier is an optional capability of the dns.Client. Handler.Init
// stores the client as an ownLinkVerifier when its dynamic type implements
// this interface, and Handler.isOwnLink consults it.
type ownLinkVerifier interface {
	// IsOwnLink reports whether the session carried by ctx is the DNS
	// client's own link (exact semantics are defined by the implementation,
	// which is not visible here). When it returns true, Process forwards
	// every message upstream instead of answering or dropping it locally.
	IsOwnLink(ctx context.Context) bool
}
    44  
// Handler is the DNS outbound handler. It relays DNS messages between an
// inbound link and an upstream destination, and can answer A/AAAA questions
// itself through its dns.Client.
//
// Fields:
//   - client: resolver used by handleIPQuery to answer A/AAAA questions.
//   - fdns: FakeDNS engine; only set if the FakeDNS feature callback in init
//     runs, otherwise nil.
//   - ownLinkVerifier: set by Init when client implements ownLinkVerifier.
//   - server: upstream override; Process substitutes its non-zero network,
//     address and port into the session's outbound target.
//   - timeout: inactivity timeout, taken from the ConnectionIdle timeout of
//     the policy for the configured user level.
//   - nonIPQuery: copied from config.Non_IPQuery; "drop" makes Process discard
//     queries that are not A/AAAA instead of forwarding them.
type Handler struct {
	client          dns.Client
	fdns            dns.FakeDNSEngine
	ownLinkVerifier ownLinkVerifier
	server          net.Destination
	timeout         time.Duration
	nonIPQuery      string
}
    53  
    54  func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
    55  	h.client = dnsClient
    56  	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
    57  
    58  	if v, ok := dnsClient.(ownLinkVerifier); ok {
    59  		h.ownLinkVerifier = v
    60  	}
    61  
    62  	if config.Server != nil {
    63  		h.server = config.Server.AsDestination()
    64  	}
    65  	h.nonIPQuery = config.Non_IPQuery
    66  	return nil
    67  }
    68  
    69  func (h *Handler) isOwnLink(ctx context.Context) bool {
    70  	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
    71  }
    72  
    73  func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
    74  	var parser dnsmessage.Parser
    75  	header, err := parser.Start(b)
    76  	if err != nil {
    77  		newError("parser start").Base(err).WriteToLog()
    78  		return
    79  	}
    80  
    81  	id = header.ID
    82  	q, err := parser.Question()
    83  	if err != nil {
    84  		newError("question").Base(err).WriteToLog()
    85  		return
    86  	}
    87  	qType = q.Type
    88  	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
    89  		return
    90  	}
    91  
    92  	domain = q.Name.String()
    93  	r = true
    94  	return
    95  }
    96  
    97  // Process implements proxy.Outbound.
    98  func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
    99  	outbound := session.OutboundFromContext(ctx)
   100  	if outbound == nil || !outbound.Target.IsValid() {
   101  		return newError("invalid outbound")
   102  	}
   103  	outbound.Name = "dns"
   104  
   105  	srcNetwork := outbound.Target.Network
   106  
   107  	dest := outbound.Target
   108  	if h.server.Network != net.Network_Unknown {
   109  		dest.Network = h.server.Network
   110  	}
   111  	if h.server.Address != nil {
   112  		dest.Address = h.server.Address
   113  	}
   114  	if h.server.Port != 0 {
   115  		dest.Port = h.server.Port
   116  	}
   117  
   118  	newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))
   119  
   120  	conn := &outboundConn{
   121  		dialer: func() (stat.Connection, error) {
   122  			return d.Dial(ctx, dest)
   123  		},
   124  		connReady: make(chan struct{}, 1),
   125  	}
   126  
   127  	var reader dns_proto.MessageReader
   128  	var writer dns_proto.MessageWriter
   129  	if srcNetwork == net.Network_TCP {
   130  		reader = dns_proto.NewTCPReader(link.Reader)
   131  		writer = &dns_proto.TCPWriter{
   132  			Writer: link.Writer,
   133  		}
   134  	} else {
   135  		reader = &dns_proto.UDPReader{
   136  			Reader: link.Reader,
   137  		}
   138  		writer = &dns_proto.UDPWriter{
   139  			Writer: link.Writer,
   140  		}
   141  	}
   142  
   143  	var connReader dns_proto.MessageReader
   144  	var connWriter dns_proto.MessageWriter
   145  	if dest.Network == net.Network_TCP {
   146  		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
   147  		connWriter = &dns_proto.TCPWriter{
   148  			Writer: buf.NewWriter(conn),
   149  		}
   150  	} else {
   151  		connReader = &dns_proto.UDPReader{
   152  			Reader: buf.NewPacketReader(conn),
   153  		}
   154  		connWriter = &dns_proto.UDPWriter{
   155  			Writer: buf.NewWriter(conn),
   156  		}
   157  	}
   158  
   159  	if session.TimeoutOnlyFromContext(ctx) {
   160  		ctx, _ = context.WithCancel(context.Background())
   161  	}
   162  
   163  	ctx, cancel := context.WithCancel(ctx)
   164  	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
   165  
   166  	request := func() error {
   167  		defer conn.Close()
   168  
   169  		for {
   170  			b, err := reader.ReadMessage()
   171  			if err == io.EOF {
   172  				return nil
   173  			}
   174  
   175  			if err != nil {
   176  				return err
   177  			}
   178  
   179  			timer.Update()
   180  
   181  			if !h.isOwnLink(ctx) {
   182  				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
   183  				if isIPQuery {
   184  					go h.handleIPQuery(id, qType, domain, writer)
   185  				}
   186  				if isIPQuery || h.nonIPQuery == "drop" || qType == 65 {
   187  					b.Release()
   188  					continue
   189  				}
   190  			}
   191  
   192  			if err := connWriter.WriteMessage(b); err != nil {
   193  				return err
   194  			}
   195  		}
   196  	}
   197  
   198  	response := func() error {
   199  		for {
   200  			b, err := connReader.ReadMessage()
   201  			if err == io.EOF {
   202  				return nil
   203  			}
   204  
   205  			if err != nil {
   206  				return err
   207  			}
   208  
   209  			timer.Update()
   210  
   211  			if err := writer.WriteMessage(b); err != nil {
   212  				return err
   213  			}
   214  		}
   215  	}
   216  
   217  	if err := task.Run(ctx, request, response); err != nil {
   218  		return newError("connection ends").Base(err)
   219  	}
   220  
   221  	return nil
   222  }
   223  
   224  func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
   225  	var ips []net.IP
   226  	var err error
   227  
   228  	var ttl uint32 = 600
   229  
   230  	switch qType {
   231  	case dnsmessage.TypeA:
   232  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   233  			IPv4Enable: true,
   234  			IPv6Enable: false,
   235  			FakeEnable: true,
   236  		})
   237  	case dnsmessage.TypeAAAA:
   238  		ips, err = h.client.LookupIP(domain, dns.IPOption{
   239  			IPv4Enable: false,
   240  			IPv6Enable: true,
   241  			FakeEnable: true,
   242  		})
   243  	}
   244  
   245  	rcode := dns.RCodeFromError(err)
   246  	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {
   247  		newError("ip query").Base(err).WriteToLog()
   248  		return
   249  	}
   250  
   251  	if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) {
   252  		ttl = 1
   253  	}
   254  
   255  	switch qType {
   256  	case dnsmessage.TypeA:
   257  		for i, ip := range ips {
   258  			ips[i] = ip.To4()
   259  		}
   260  	case dnsmessage.TypeAAAA:
   261  		for i, ip := range ips {
   262  			ips[i] = ip.To16()
   263  		}
   264  	}
   265  
   266  	b := buf.New()
   267  	rawBytes := b.Extend(buf.Size)
   268  	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
   269  		ID:                 id,
   270  		RCode:              dnsmessage.RCode(rcode),
   271  		RecursionAvailable: true,
   272  		RecursionDesired:   true,
   273  		Response:           true,
   274  		Authoritative:      true,
   275  	})
   276  	builder.EnableCompression()
   277  	common.Must(builder.StartQuestions())
   278  	common.Must(builder.Question(dnsmessage.Question{
   279  		Name:  dnsmessage.MustNewName(domain),
   280  		Class: dnsmessage.ClassINET,
   281  		Type:  qType,
   282  	}))
   283  	common.Must(builder.StartAnswers())
   284  
   285  	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}
   286  	for _, ip := range ips {
   287  		if len(ip) == net.IPv4len {
   288  			var r dnsmessage.AResource
   289  			copy(r.A[:], ip)
   290  			common.Must(builder.AResource(rHeader, r))
   291  		} else {
   292  			var r dnsmessage.AAAAResource
   293  			copy(r.AAAA[:], ip)
   294  			common.Must(builder.AAAAResource(rHeader, r))
   295  		}
   296  	}
   297  	msgBytes, err := builder.Finish()
   298  	if err != nil {
   299  		newError("pack message").Base(err).WriteToLog()
   300  		b.Release()
   301  		return
   302  	}
   303  	b.Resize(0, int32(len(msgBytes)))
   304  
   305  	if err := writer.WriteMessage(b); err != nil {
   306  		newError("write IP answer").Base(err).WriteToLog()
   307  	}
   308  }
   309  
// outboundConn wraps the upstream connection and dials it lazily on the first
// Write. Read waits until a dial has succeeded or the wrapper has been closed.
type outboundConn struct {
	// access guards conn and serializes dialing against Close.
	access sync.Mutex
	// dialer opens the upstream connection.
	dialer func() (stat.Connection, error)

	// conn is nil until the first successful dial.
	conn net.Conn
	// connReady receives one value after a successful dial and is closed by
	// Close; Read blocks on it while conn is nil.
	connReady chan struct{}
}
   317  
   318  func (c *outboundConn) dial() error {
   319  	conn, err := c.dialer()
   320  	if err != nil {
   321  		return err
   322  	}
   323  	c.conn = conn
   324  	c.connReady <- struct{}{}
   325  	return nil
   326  }
   327  
   328  func (c *outboundConn) Write(b []byte) (int, error) {
   329  	c.access.Lock()
   330  
   331  	if c.conn == nil {
   332  		if err := c.dial(); err != nil {
   333  			c.access.Unlock()
   334  			newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog()
   335  			return len(b), nil
   336  		}
   337  	}
   338  
   339  	c.access.Unlock()
   340  
   341  	return c.conn.Write(b)
   342  }
   343  
   344  func (c *outboundConn) Read(b []byte) (int, error) {
   345  	var conn net.Conn
   346  	c.access.Lock()
   347  	conn = c.conn
   348  	c.access.Unlock()
   349  
   350  	if conn == nil {
   351  		_, open := <-c.connReady
   352  		if !open {
   353  			return 0, io.EOF
   354  		}
   355  		conn = c.conn
   356  	}
   357  
   358  	return conn.Read(b)
   359  }
   360  
   361  func (c *outboundConn) Close() error {
   362  	c.access.Lock()
   363  	close(c.connReady)
   364  	if c.conn != nil {
   365  		c.conn.Close()
   366  	}
   367  	c.access.Unlock()
   368  	return nil
   369  }