github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/nameserver_tcp.go (about)

     1  package dns
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"net/url"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/xmplusdev/xmcore/common"
    13  	"github.com/xmplusdev/xmcore/common/buf"
    14  	"github.com/xmplusdev/xmcore/common/log"
    15  	"github.com/xmplusdev/xmcore/common/net"
    16  	"github.com/xmplusdev/xmcore/common/net/cnc"
    17  	"github.com/xmplusdev/xmcore/common/protocol/dns"
    18  	"github.com/xmplusdev/xmcore/common/session"
    19  	"github.com/xmplusdev/xmcore/common/signal/pubsub"
    20  	"github.com/xmplusdev/xmcore/common/task"
    21  	dns_feature "github.com/xmplusdev/xmcore/features/dns"
    22  	"github.com/xmplusdev/xmcore/features/routing"
    23  	"github.com/xmplusdev/xmcore/transport/internet"
    24  	"golang.org/x/net/dns/dnsmessage"
    25  )
    26  
// TCPNameServer implements DNS over TCP (RFC 7766).
//
// The embedded RWMutex guards the ips cache.
type TCPNameServer struct {
	sync.RWMutex
	// name is the display name used in logs: prefix + "//" + destination address.
	name          string
	// destination is the TCP endpoint of the upstream server (port 53 unless the URL says otherwise).
	destination   *net.Destination
	// ips caches the answers per domain; it is keyed by the domain string passed to updateIP/findIPsForDomain.
	ips           map[string]*record
	// pub notifies waiting QueryIP calls; topics are domain+"4" (A) and domain+"6" (AAAA).
	pub           *pubsub.Service
	// cleanup periodically runs Cleanup (once a minute) to drop expired records.
	cleanup       *task.Periodic
	// reqID is the request-ID counter; it is only modified atomically (see newReqID).
	reqID         uint32
	// dial opens a connection to destination; it is set by the constructors.
	dial          func(context.Context) (net.Conn, error)
	// queryStrategy restricts which address families QueryIP may ask for.
	queryStrategy QueryStrategy
}
    39  
    40  // NewTCPNameServer creates DNS over TCP server object for remote resolving.
    41  func NewTCPNameServer(
    42  	url *url.URL,
    43  	dispatcher routing.Dispatcher,
    44  	queryStrategy QueryStrategy,
    45  ) (*TCPNameServer, error) {
    46  	s, err := baseTCPNameServer(url, "TCP", queryStrategy)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	s.dial = func(ctx context.Context) (net.Conn, error) {
    52  		link, err := dispatcher.Dispatch(toDnsContext(ctx, s.destination.String()), *s.destination)
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  
    57  		return cnc.NewConnection(
    58  			cnc.ConnectionInputMulti(link.Writer),
    59  			cnc.ConnectionOutputMulti(link.Reader),
    60  		), nil
    61  	}
    62  
    63  	return s, nil
    64  }
    65  
    66  // NewTCPLocalNameServer creates DNS over TCP client object for local resolving
    67  func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameServer, error) {
    68  	s, err := baseTCPNameServer(url, "TCPL", queryStrategy)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	s.dial = func(ctx context.Context) (net.Conn, error) {
    74  		return internet.DialSystem(ctx, *s.destination, nil)
    75  	}
    76  
    77  	return s, nil
    78  }
    79  
    80  func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) (*TCPNameServer, error) {
    81  	port := net.Port(53)
    82  	if url.Port() != "" {
    83  		var err error
    84  		if port, err = net.PortFromString(url.Port()); err != nil {
    85  			return nil, err
    86  		}
    87  	}
    88  	dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
    89  
    90  	s := &TCPNameServer{
    91  		destination:   &dest,
    92  		ips:           make(map[string]*record),
    93  		pub:           pubsub.NewService(),
    94  		name:          prefix + "//" + dest.NetAddr(),
    95  		queryStrategy: queryStrategy,
    96  	}
    97  	s.cleanup = &task.Periodic{
    98  		Interval: time.Minute,
    99  		Execute:  s.Cleanup,
   100  	}
   101  
   102  	return s, nil
   103  }
   104  
// Name implements Server. It returns the display name assigned at
// construction time (prefix + "//" + destination address).
func (s *TCPNameServer) Name() string {
	return s.name
}
   109  
   110  // Cleanup clears expired items from cache
   111  func (s *TCPNameServer) Cleanup() error {
   112  	now := time.Now()
   113  	s.Lock()
   114  	defer s.Unlock()
   115  
   116  	if len(s.ips) == 0 {
   117  		return newError("nothing to do. stopping...")
   118  	}
   119  
   120  	for domain, record := range s.ips {
   121  		if record.A != nil && record.A.Expire.Before(now) {
   122  			record.A = nil
   123  		}
   124  		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
   125  			record.AAAA = nil
   126  		}
   127  
   128  		if record.A == nil && record.AAAA == nil {
   129  			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
   130  			delete(s.ips, domain)
   131  		} else {
   132  			s.ips[domain] = record
   133  		}
   134  	}
   135  
   136  	if len(s.ips) == 0 {
   137  		s.ips = make(map[string]*record)
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
   144  	elapsed := time.Since(req.start)
   145  
   146  	s.Lock()
   147  	rec, found := s.ips[req.domain]
   148  	if !found {
   149  		rec = &record{}
   150  	}
   151  	updated := false
   152  
   153  	switch req.reqType {
   154  	case dnsmessage.TypeA:
   155  		if isNewer(rec.A, ipRec) {
   156  			rec.A = ipRec
   157  			updated = true
   158  		}
   159  	case dnsmessage.TypeAAAA:
   160  		addr := make([]net.Address, 0)
   161  		for _, ip := range ipRec.IP {
   162  			if len(ip.IP()) == net.IPv6len {
   163  				addr = append(addr, ip)
   164  			}
   165  		}
   166  		ipRec.IP = addr
   167  		if isNewer(rec.AAAA, ipRec) {
   168  			rec.AAAA = ipRec
   169  			updated = true
   170  		}
   171  	}
   172  	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
   173  
   174  	if updated {
   175  		s.ips[req.domain] = rec
   176  	}
   177  	switch req.reqType {
   178  	case dnsmessage.TypeA:
   179  		s.pub.Publish(req.domain+"4", nil)
   180  	case dnsmessage.TypeAAAA:
   181  		s.pub.Publish(req.domain+"6", nil)
   182  	}
   183  	s.Unlock()
   184  	common.Must(s.cleanup.Start())
   185  }
   186  
   187  func (s *TCPNameServer) newReqID() uint16 {
   188  	return uint16(atomic.AddUint32(&s.reqID, 1))
   189  }
   190  
   191  func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
   192  	newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
   193  
   194  	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
   195  
   196  	var deadline time.Time
   197  	if d, ok := ctx.Deadline(); ok {
   198  		deadline = d
   199  	} else {
   200  		deadline = time.Now().Add(time.Second * 5)
   201  	}
   202  
   203  	for _, req := range reqs {
   204  		go func(r *dnsRequest) {
   205  			dnsCtx := ctx
   206  
   207  			if inbound := session.InboundFromContext(ctx); inbound != nil {
   208  				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   209  			}
   210  
   211  			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
   212  				Protocol:       "dns",
   213  				SkipDNSResolve: true,
   214  			})
   215  
   216  			var cancel context.CancelFunc
   217  			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
   218  			defer cancel()
   219  
   220  			b, err := dns.PackMessage(r.msg)
   221  			if err != nil {
   222  				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
   223  				return
   224  			}
   225  
   226  			conn, err := s.dial(dnsCtx)
   227  			if err != nil {
   228  				newError("failed to dial namesever").Base(err).AtError().WriteToLog()
   229  				return
   230  			}
   231  			defer conn.Close()
   232  			dnsReqBuf := buf.New()
   233  			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
   234  			dnsReqBuf.Write(b.Bytes())
   235  			b.Release()
   236  
   237  			_, err = conn.Write(dnsReqBuf.Bytes())
   238  			if err != nil {
   239  				newError("failed to send query").Base(err).AtError().WriteToLog()
   240  				return
   241  			}
   242  			dnsReqBuf.Release()
   243  
   244  			respBuf := buf.New()
   245  			defer respBuf.Release()
   246  			n, err := respBuf.ReadFullFrom(conn, 2)
   247  			if err != nil && n == 0 {
   248  				newError("failed to read response length").Base(err).AtError().WriteToLog()
   249  				return
   250  			}
   251  			var length int16
   252  			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
   253  			if err != nil {
   254  				newError("failed to parse response length").Base(err).AtError().WriteToLog()
   255  				return
   256  			}
   257  			respBuf.Clear()
   258  			n, err = respBuf.ReadFullFrom(conn, int32(length))
   259  			if err != nil && n == 0 {
   260  				newError("failed to read response length").Base(err).AtError().WriteToLog()
   261  				return
   262  			}
   263  
   264  			rec, err := parseResponse(respBuf.Bytes())
   265  			if err != nil {
   266  				newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog()
   267  				return
   268  			}
   269  
   270  			s.updateIP(r, rec)
   271  		}(req)
   272  	}
   273  }
   274  
   275  func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
   276  	s.RLock()
   277  	record, found := s.ips[domain]
   278  	s.RUnlock()
   279  
   280  	if !found {
   281  		return nil, errRecordNotFound
   282  	}
   283  
   284  	var err4 error
   285  	var err6 error
   286  	var ips []net.Address
   287  	var ip6 []net.Address
   288  
   289  	if option.IPv4Enable {
   290  		ips, err4 = record.A.getIPs()
   291  	}
   292  
   293  	if option.IPv6Enable {
   294  		ip6, err6 = record.AAAA.getIPs()
   295  		ips = append(ips, ip6...)
   296  	}
   297  
   298  	if len(ips) > 0 {
   299  		return toNetIP(ips)
   300  	}
   301  
   302  	if err4 != nil {
   303  		return nil, err4
   304  	}
   305  
   306  	if err6 != nil {
   307  		return nil, err6
   308  	}
   309  
   310  	return nil, dns_feature.ErrEmptyResponse
   311  }
   312  
   313  // QueryIP implements Server.
   314  func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
   315  	fqdn := Fqdn(domain)
   316  	option = ResolveIpOptionOverride(s.queryStrategy, option)
   317  	if !option.IPv4Enable && !option.IPv6Enable {
   318  		return nil, dns_feature.ErrEmptyResponse
   319  	}
   320  
   321  	if disableCache {
   322  		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
   323  	} else {
   324  		ips, err := s.findIPsForDomain(fqdn, option)
   325  		if err != errRecordNotFound {
   326  			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
   327  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
   328  			return ips, err
   329  		}
   330  	}
   331  
   332  	// ipv4 and ipv6 belong to different subscription groups
   333  	var sub4, sub6 *pubsub.Subscriber
   334  	if option.IPv4Enable {
   335  		sub4 = s.pub.Subscribe(fqdn + "4")
   336  		defer sub4.Close()
   337  	}
   338  	if option.IPv6Enable {
   339  		sub6 = s.pub.Subscribe(fqdn + "6")
   340  		defer sub6.Close()
   341  	}
   342  	done := make(chan interface{})
   343  	go func() {
   344  		if sub4 != nil {
   345  			select {
   346  			case <-sub4.Wait():
   347  			case <-ctx.Done():
   348  			}
   349  		}
   350  		if sub6 != nil {
   351  			select {
   352  			case <-sub6.Wait():
   353  			case <-ctx.Done():
   354  			}
   355  		}
   356  		close(done)
   357  	}()
   358  	s.sendQuery(ctx, fqdn, clientIP, option)
   359  	start := time.Now()
   360  
   361  	for {
   362  		ips, err := s.findIPsForDomain(fqdn, option)
   363  		if err != errRecordNotFound {
   364  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
   365  			return ips, err
   366  		}
   367  
   368  		select {
   369  		case <-ctx.Done():
   370  			return nil, ctx.Err()
   371  		case <-done:
   372  		}
   373  	}
   374  }