github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/nameserver_doh.go (about)

     1  package dns
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/xmplusdev/xmcore/common"
    15  	"github.com/xmplusdev/xmcore/common/log"
    16  	"github.com/xmplusdev/xmcore/common/net"
    17  	"github.com/xmplusdev/xmcore/common/net/cnc"
    18  	"github.com/xmplusdev/xmcore/common/protocol/dns"
    19  	"github.com/xmplusdev/xmcore/common/session"
    20  	"github.com/xmplusdev/xmcore/common/signal/pubsub"
    21  	"github.com/xmplusdev/xmcore/common/task"
    22  	dns_feature "github.com/xmplusdev/xmcore/features/dns"
    23  	"github.com/xmplusdev/xmcore/features/routing"
    24  	"github.com/xmplusdev/xmcore/transport/internet"
    25  	"golang.org/x/net/dns/dnsmessage"
    26  )
    27  
    28  // DoHNameServer implemented DNS over HTTPS (RFC8484) Wire Format,
    29  // which is compatible with traditional dns over udp(RFC1035),
    30  // thus most of the DOH implementation is copied from udpns.go
    31  type DoHNameServer struct {
    32  	dispatcher routing.Dispatcher
    33  	sync.RWMutex
    34  	ips           map[string]*record
    35  	pub           *pubsub.Service
    36  	cleanup       *task.Periodic
    37  	reqID         uint32
    38  	httpClient    *http.Client
    39  	dohURL        string
    40  	name          string
    41  	queryStrategy QueryStrategy
    42  }
    43  
    44  // NewDoHNameServer creates DOH server object for remote resolving.
    45  func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (*DoHNameServer, error) {
    46  	newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog()
    47  	s := baseDOHNameServer(url, "DOH", queryStrategy)
    48  
    49  	s.dispatcher = dispatcher
    50  	tr := &http.Transport{
    51  		MaxIdleConns:        30,
    52  		IdleConnTimeout:     90 * time.Second,
    53  		TLSHandshakeTimeout: 30 * time.Second,
    54  		ForceAttemptHTTP2:   true,
    55  		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
    56  			dest, err := net.ParseDestination(network + ":" + addr)
    57  			if err != nil {
    58  				return nil, err
    59  			}
    60  			link, err := s.dispatcher.Dispatch(toDnsContext(ctx, s.dohURL), dest)
    61  			select {
    62  			case <-ctx.Done():
    63  				return nil, ctx.Err()
    64  			default:
    65  
    66  			}
    67  			if err != nil {
    68  				return nil, err
    69  			}
    70  
    71  			cc := common.ChainedClosable{}
    72  			if cw, ok := link.Writer.(common.Closable); ok {
    73  				cc = append(cc, cw)
    74  			}
    75  			if cr, ok := link.Reader.(common.Closable); ok {
    76  				cc = append(cc, cr)
    77  			}
    78  			return cnc.NewConnection(
    79  				cnc.ConnectionInputMulti(link.Writer),
    80  				cnc.ConnectionOutputMulti(link.Reader),
    81  				cnc.ConnectionOnClose(cc),
    82  			), nil
    83  		},
    84  	}
    85  	s.httpClient = &http.Client{
    86  		Timeout:   time.Second * 180,
    87  		Transport: tr,
    88  	}
    89  
    90  	return s, nil
    91  }
    92  
    93  // NewDoHLocalNameServer creates DOH client object for local resolving
    94  func NewDoHLocalNameServer(url *url.URL, queryStrategy QueryStrategy) *DoHNameServer {
    95  	url.Scheme = "https"
    96  	s := baseDOHNameServer(url, "DOHL", queryStrategy)
    97  	tr := &http.Transport{
    98  		IdleConnTimeout:   90 * time.Second,
    99  		ForceAttemptHTTP2: true,
   100  		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   101  			dest, err := net.ParseDestination(network + ":" + addr)
   102  			if err != nil {
   103  				return nil, err
   104  			}
   105  			conn, err := internet.DialSystem(ctx, dest, nil)
   106  			log.Record(&log.AccessMessage{
   107  				From:   "DNS",
   108  				To:     s.dohURL,
   109  				Status: log.AccessAccepted,
   110  				Detour: "local",
   111  			})
   112  			if err != nil {
   113  				return nil, err
   114  			}
   115  			return conn, nil
   116  		},
   117  	}
   118  	s.httpClient = &http.Client{
   119  		Timeout:   time.Second * 180,
   120  		Transport: tr,
   121  	}
   122  	newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog()
   123  	return s
   124  }
   125  
   126  func baseDOHNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) *DoHNameServer {
   127  	s := &DoHNameServer{
   128  		ips:           make(map[string]*record),
   129  		pub:           pubsub.NewService(),
   130  		name:          prefix + "//" + url.Host,
   131  		dohURL:        url.String(),
   132  		queryStrategy: queryStrategy,
   133  	}
   134  	s.cleanup = &task.Periodic{
   135  		Interval: time.Minute,
   136  		Execute:  s.Cleanup,
   137  	}
   138  	return s
   139  }
   140  
   141  // Name implements Server.
   142  func (s *DoHNameServer) Name() string {
   143  	return s.name
   144  }
   145  
   146  // Cleanup clears expired items from cache
   147  func (s *DoHNameServer) Cleanup() error {
   148  	now := time.Now()
   149  	s.Lock()
   150  	defer s.Unlock()
   151  
   152  	if len(s.ips) == 0 {
   153  		return newError("nothing to do. stopping...")
   154  	}
   155  
   156  	for domain, record := range s.ips {
   157  		if record.A != nil && record.A.Expire.Before(now) {
   158  			record.A = nil
   159  		}
   160  		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
   161  			record.AAAA = nil
   162  		}
   163  
   164  		if record.A == nil && record.AAAA == nil {
   165  			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
   166  			delete(s.ips, domain)
   167  		} else {
   168  			s.ips[domain] = record
   169  		}
   170  	}
   171  
   172  	if len(s.ips) == 0 {
   173  		s.ips = make(map[string]*record)
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
   180  	elapsed := time.Since(req.start)
   181  
   182  	s.Lock()
   183  	rec, found := s.ips[req.domain]
   184  	if !found {
   185  		rec = &record{}
   186  	}
   187  	updated := false
   188  
   189  	switch req.reqType {
   190  	case dnsmessage.TypeA:
   191  		if isNewer(rec.A, ipRec) {
   192  			rec.A = ipRec
   193  			updated = true
   194  		}
   195  	case dnsmessage.TypeAAAA:
   196  		addr := make([]net.Address, 0, len(ipRec.IP))
   197  		for _, ip := range ipRec.IP {
   198  			if len(ip.IP()) == net.IPv6len {
   199  				addr = append(addr, ip)
   200  			}
   201  		}
   202  		ipRec.IP = addr
   203  		if isNewer(rec.AAAA, ipRec) {
   204  			rec.AAAA = ipRec
   205  			updated = true
   206  		}
   207  	}
   208  	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
   209  
   210  	if updated {
   211  		s.ips[req.domain] = rec
   212  	}
   213  	switch req.reqType {
   214  	case dnsmessage.TypeA:
   215  		s.pub.Publish(req.domain+"4", nil)
   216  	case dnsmessage.TypeAAAA:
   217  		s.pub.Publish(req.domain+"6", nil)
   218  	}
   219  	s.Unlock()
   220  	common.Must(s.cleanup.Start())
   221  }
   222  
   223  func (s *DoHNameServer) newReqID() uint16 {
   224  	return uint16(atomic.AddUint32(&s.reqID, 1))
   225  }
   226  
   227  func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
   228  	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
   229  
   230  	if s.name+"." == "DOH//"+domain {
   231  		newError(s.name, " tries to resolve itself! Use IP or set \"hosts\" instead.").AtError().WriteToLog(session.ExportIDToError(ctx))
   232  		return
   233  	}
   234  
   235  	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
   236  
   237  	var deadline time.Time
   238  	if d, ok := ctx.Deadline(); ok {
   239  		deadline = d
   240  	} else {
   241  		deadline = time.Now().Add(time.Second * 5)
   242  	}
   243  
   244  	for _, req := range reqs {
   245  		go func(r *dnsRequest) {
   246  			// generate new context for each req, using same context
   247  			// may cause reqs all aborted if any one encounter an error
   248  			dnsCtx := ctx
   249  
   250  			// reserve internal dns server requested Inbound
   251  			if inbound := session.InboundFromContext(ctx); inbound != nil {
   252  				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   253  			}
   254  
   255  			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
   256  				Protocol:       "https",
   257  				SkipDNSResolve: true,
   258  			})
   259  
   260  			// forced to use mux for DOH
   261  			// dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true)
   262  
   263  			var cancel context.CancelFunc
   264  			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
   265  			defer cancel()
   266  
   267  			b, err := dns.PackMessage(r.msg)
   268  			if err != nil {
   269  				newError("failed to pack dns query for ", domain).Base(err).AtError().WriteToLog()
   270  				return
   271  			}
   272  			resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
   273  			if err != nil {
   274  				newError("failed to retrieve response for ", domain).Base(err).AtError().WriteToLog()
   275  				return
   276  			}
   277  			rec, err := parseResponse(resp)
   278  			if err != nil {
   279  				newError("failed to handle DOH response for ", domain).Base(err).AtError().WriteToLog()
   280  				return
   281  			}
   282  			s.updateIP(r, rec)
   283  		}(req)
   284  	}
   285  }
   286  
   287  func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) {
   288  	body := bytes.NewBuffer(b)
   289  	req, err := http.NewRequest("POST", s.dohURL, body)
   290  	if err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	req.Header.Add("Accept", "application/dns-message")
   295  	req.Header.Add("Content-Type", "application/dns-message")
   296  
   297  	hc := s.httpClient
   298  
   299  	resp, err := hc.Do(req.WithContext(ctx))
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  
   304  	defer resp.Body.Close()
   305  	if resp.StatusCode != http.StatusOK {
   306  		io.Copy(io.Discard, resp.Body) // flush resp.Body so that the conn is reusable
   307  		return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode)
   308  	}
   309  
   310  	return io.ReadAll(resp.Body)
   311  }
   312  
   313  func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
   314  	s.RLock()
   315  	record, found := s.ips[domain]
   316  	s.RUnlock()
   317  
   318  	if !found {
   319  		return nil, errRecordNotFound
   320  	}
   321  
   322  	var err4 error
   323  	var err6 error
   324  	var ips []net.Address
   325  	var ip6 []net.Address
   326  
   327  	if option.IPv4Enable {
   328  		ips, err4 = record.A.getIPs()
   329  	}
   330  
   331  	if option.IPv6Enable {
   332  		ip6, err6 = record.AAAA.getIPs()
   333  		ips = append(ips, ip6...)
   334  	}
   335  
   336  	if len(ips) > 0 {
   337  		return toNetIP(ips)
   338  	}
   339  
   340  	if err4 != nil {
   341  		return nil, err4
   342  	}
   343  
   344  	if err6 != nil {
   345  		return nil, err6
   346  	}
   347  
   348  	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
   349  		return nil, dns_feature.ErrEmptyResponse
   350  	}
   351  
   352  	return nil, errRecordNotFound
   353  }
   354  
   355  // QueryIP implements Server.
   356  func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl
   357  	fqdn := Fqdn(domain)
   358  	option = ResolveIpOptionOverride(s.queryStrategy, option)
   359  	if !option.IPv4Enable && !option.IPv6Enable {
   360  		return nil, dns_feature.ErrEmptyResponse
   361  	}
   362  
   363  	if disableCache {
   364  		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
   365  	} else {
   366  		ips, err := s.findIPsForDomain(fqdn, option)
   367  		if err != errRecordNotFound {
   368  			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
   369  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
   370  			return ips, err
   371  		}
   372  	}
   373  
   374  	// ipv4 and ipv6 belong to different subscription groups
   375  	var sub4, sub6 *pubsub.Subscriber
   376  	if option.IPv4Enable {
   377  		sub4 = s.pub.Subscribe(fqdn + "4")
   378  		defer sub4.Close()
   379  	}
   380  	if option.IPv6Enable {
   381  		sub6 = s.pub.Subscribe(fqdn + "6")
   382  		defer sub6.Close()
   383  	}
   384  	done := make(chan interface{})
   385  	go func() {
   386  		if sub4 != nil {
   387  			select {
   388  			case <-sub4.Wait():
   389  			case <-ctx.Done():
   390  			}
   391  		}
   392  		if sub6 != nil {
   393  			select {
   394  			case <-sub6.Wait():
   395  			case <-ctx.Done():
   396  			}
   397  		}
   398  		close(done)
   399  	}()
   400  	s.sendQuery(ctx, fqdn, clientIP, option)
   401  	start := time.Now()
   402  
   403  	for {
   404  		ips, err := s.findIPsForDomain(fqdn, option)
   405  		if err != errRecordNotFound {
   406  			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
   407  			return ips, err
   408  		}
   409  
   410  		select {
   411  		case <-ctx.Done():
   412  			return nil, ctx.Err()
   413  		case <-done:
   414  		}
   415  	}
   416  }