github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/dns.go (about)

     1  // Package dns is an implementation of core.DNS feature.
     2  package dns
     3  
     4  //go:generate go run github.com/xmplusdev/xmcore/common/errors/errorgen
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/xmplusdev/xmcore/app/router"
    13  	"github.com/xmplusdev/xmcore/common"
    14  	"github.com/xmplusdev/xmcore/common/errors"
    15  	"github.com/xmplusdev/xmcore/common/net"
    16  	"github.com/xmplusdev/xmcore/common/session"
    17  	"github.com/xmplusdev/xmcore/common/strmatcher"
    18  	"github.com/xmplusdev/xmcore/features"
    19  	"github.com/xmplusdev/xmcore/features/dns"
    20  )
    21  
    22  // DNS is a DNS rely server.
    23  type DNS struct {
    24  	sync.Mutex
    25  	tag                    string
    26  	disableCache           bool
    27  	disableFallback        bool
    28  	disableFallbackIfMatch bool
    29  	ipOption               *dns.IPOption
    30  	hosts                  *StaticHosts
    31  	clients                []*Client
    32  	ctx                    context.Context
    33  	domainMatcher          strmatcher.IndexMatcher
    34  	matcherInfos           []*DomainMatcherInfo
    35  }
    36  
    37  // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher
    38  type DomainMatcherInfo struct {
    39  	clientIdx     uint16
    40  	domainRuleIdx uint16
    41  }
    42  
    43  // New creates a new DNS server with given configuration.
    44  func New(ctx context.Context, config *Config) (*DNS, error) {
    45  	var tag string
    46  	if len(config.Tag) > 0 {
    47  		tag = config.Tag
    48  	} else {
    49  		tag = generateRandomTag()
    50  	}
    51  
    52  	var clientIP net.IP
    53  	switch len(config.ClientIp) {
    54  	case 0, net.IPv4len, net.IPv6len:
    55  		clientIP = net.IP(config.ClientIp)
    56  	default:
    57  		return nil, newError("unexpected client IP length ", len(config.ClientIp))
    58  	}
    59  
    60  	var ipOption *dns.IPOption
    61  	switch config.QueryStrategy {
    62  	case QueryStrategy_USE_IP:
    63  		ipOption = &dns.IPOption{
    64  			IPv4Enable: true,
    65  			IPv6Enable: true,
    66  			FakeEnable: false,
    67  		}
    68  	case QueryStrategy_USE_IP4:
    69  		ipOption = &dns.IPOption{
    70  			IPv4Enable: true,
    71  			IPv6Enable: false,
    72  			FakeEnable: false,
    73  		}
    74  	case QueryStrategy_USE_IP6:
    75  		ipOption = &dns.IPOption{
    76  			IPv4Enable: false,
    77  			IPv6Enable: true,
    78  			FakeEnable: false,
    79  		}
    80  	}
    81  
    82  	hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts)
    83  	if err != nil {
    84  		return nil, newError("failed to create hosts").Base(err)
    85  	}
    86  
    87  	clients := []*Client{}
    88  	domainRuleCount := 0
    89  	for _, ns := range config.NameServer {
    90  		domainRuleCount += len(ns.PrioritizedDomain)
    91  	}
    92  
    93  	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
    94  	matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1)
    95  	domainMatcher := &strmatcher.MatcherGroup{}
    96  	geoipContainer := router.GeoIPMatcherContainer{}
    97  
    98  	for _, endpoint := range config.NameServers {
    99  		features.PrintDeprecatedFeatureWarning("simple DNS server")
   100  		client, err := NewSimpleClient(ctx, endpoint, clientIP)
   101  		if err != nil {
   102  			return nil, newError("failed to create client").Base(err)
   103  		}
   104  		clients = append(clients, client)
   105  	}
   106  
   107  	for _, ns := range config.NameServer {
   108  		clientIdx := len(clients)
   109  		updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) error {
   110  			midx := domainMatcher.Add(domainRule)
   111  			matcherInfos[midx] = &DomainMatcherInfo{
   112  				clientIdx:     uint16(clientIdx),
   113  				domainRuleIdx: uint16(originalRuleIdx),
   114  			}
   115  			return nil
   116  		}
   117  
   118  		myClientIP := clientIP
   119  		switch len(ns.ClientIp) {
   120  		case net.IPv4len, net.IPv6len:
   121  			myClientIP = net.IP(ns.ClientIp)
   122  		}
   123  		client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
   124  		if err != nil {
   125  			return nil, newError("failed to create client").Base(err)
   126  		}
   127  		clients = append(clients, client)
   128  	}
   129  
   130  	// If there is no DNS client in config, add a `localhost` DNS client
   131  	if len(clients) == 0 {
   132  		clients = append(clients, NewLocalDNSClient())
   133  	}
   134  
   135  	return &DNS{
   136  		tag:                    tag,
   137  		hosts:                  hosts,
   138  		ipOption:               ipOption,
   139  		clients:                clients,
   140  		ctx:                    ctx,
   141  		domainMatcher:          domainMatcher,
   142  		matcherInfos:           matcherInfos,
   143  		disableCache:           config.DisableCache,
   144  		disableFallback:        config.DisableFallback,
   145  		disableFallbackIfMatch: config.DisableFallbackIfMatch,
   146  	}, nil
   147  }
   148  
   149  // Type implements common.HasType.
   150  func (*DNS) Type() interface{} {
   151  	return dns.ClientType()
   152  }
   153  
   154  // Start implements common.Runnable.
   155  func (s *DNS) Start() error {
   156  	return nil
   157  }
   158  
   159  // Close implements common.Closable.
   160  func (s *DNS) Close() error {
   161  	return nil
   162  }
   163  
   164  // IsOwnLink implements proxy.dns.ownLinkVerifier
   165  func (s *DNS) IsOwnLink(ctx context.Context) bool {
   166  	inbound := session.InboundFromContext(ctx)
   167  	return inbound != nil && inbound.Tag == s.tag
   168  }
   169  
   170  // LookupIP implements dns.Client.
   171  func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) {
   172  	if domain == "" {
   173  		return nil, newError("empty domain name")
   174  	}
   175  
   176  	option.IPv4Enable = option.IPv4Enable && s.ipOption.IPv4Enable
   177  	option.IPv6Enable = option.IPv6Enable && s.ipOption.IPv6Enable
   178  
   179  	if !option.IPv4Enable && !option.IPv6Enable {
   180  		return nil, dns.ErrEmptyResponse
   181  	}
   182  
   183  	// Normalize the FQDN form query
   184  	domain = strings.TrimSuffix(domain, ".")
   185  
   186  	// Static host lookup
   187  	switch addrs := s.hosts.Lookup(domain, option); {
   188  	case addrs == nil: // Domain not recorded in static host
   189  		break
   190  	case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled)
   191  		return nil, dns.ErrEmptyResponse
   192  	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement
   193  		newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog()
   194  		domain = addrs[0].Domain()
   195  	default: // Successfully found ip records in static host
   196  		newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog()
   197  		return toNetIP(addrs)
   198  	}
   199  
   200  	// Name servers lookup
   201  	errs := []error{}
   202  	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
   203  	for _, client := range s.sortClients(domain) {
   204  		if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
   205  			newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog()
   206  			continue
   207  		}
   208  		ips, err := client.QueryIP(ctx, domain, option, s.disableCache)
   209  		if len(ips) > 0 {
   210  			return ips, nil
   211  		}
   212  		if err != nil {
   213  			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
   214  			errs = append(errs, err)
   215  		}
   216  		// 5 for RcodeRefused in miekg/dns, hardcode to reduce binary size
   217  		if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse && dns.RCodeFromError(err) != 5 {
   218  			return nil, err
   219  		}
   220  	}
   221  
   222  	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
   223  }
   224  
   225  // LookupHosts implements dns.HostsLookup.
   226  func (s *DNS) LookupHosts(domain string) *net.Address {
   227  	domain = strings.TrimSuffix(domain, ".")
   228  	if domain == "" {
   229  		return nil
   230  	}
   231  	// Normalize the FQDN form query
   232  	addrs := s.hosts.Lookup(domain, *s.ipOption)
   233  	if len(addrs) > 0 {
   234  		newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog()
   235  		return &addrs[0]
   236  	}
   237  
   238  	return nil
   239  }
   240  
   241  // GetIPOption implements ClientWithIPOption.
   242  func (s *DNS) GetIPOption() *dns.IPOption {
   243  	return s.ipOption
   244  }
   245  
   246  // SetQueryOption implements ClientWithIPOption.
   247  func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
   248  	s.ipOption.IPv4Enable = isIPv4Enable
   249  	s.ipOption.IPv6Enable = isIPv6Enable
   250  }
   251  
   252  // SetFakeDNSOption implements ClientWithIPOption.
   253  func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
   254  	s.ipOption.FakeEnable = isFakeEnable
   255  }
   256  
   257  func (s *DNS) sortClients(domain string) []*Client {
   258  	clients := make([]*Client, 0, len(s.clients))
   259  	clientUsed := make([]bool, len(s.clients))
   260  	clientNames := make([]string, 0, len(s.clients))
   261  	domainRules := []string{}
   262  
   263  	// Priority domain matching
   264  	hasMatch := false
   265  	for _, match := range s.domainMatcher.Match(domain) {
   266  		info := s.matcherInfos[match]
   267  		client := s.clients[info.clientIdx]
   268  		domainRule := client.domains[info.domainRuleIdx]
   269  		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
   270  		if clientUsed[info.clientIdx] {
   271  			continue
   272  		}
   273  		clientUsed[info.clientIdx] = true
   274  		clients = append(clients, client)
   275  		clientNames = append(clientNames, client.Name())
   276  		hasMatch = true
   277  	}
   278  
   279  	if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) {
   280  		// Default round-robin query
   281  		for idx, client := range s.clients {
   282  			if clientUsed[idx] || client.skipFallback {
   283  				continue
   284  			}
   285  			clientUsed[idx] = true
   286  			clients = append(clients, client)
   287  			clientNames = append(clientNames, client.Name())
   288  		}
   289  	}
   290  
   291  	if len(domainRules) > 0 {
   292  		newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
   293  	}
   294  	if len(clientNames) > 0 {
   295  		newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog()
   296  	}
   297  
   298  	if len(clients) == 0 {
   299  		clients = append(clients, s.clients[0])
   300  		clientNames = append(clientNames, s.clients[0].Name())
   301  		newError("domain ", domain, " will use the first DNS: ", clientNames).AtDebug().WriteToLog()
   302  	}
   303  
   304  	return clients
   305  }
   306  
   307  func init() {
   308  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   309  		return New(ctx, config.(*Config))
   310  	}))
   311  }