github.com/moqsien/xraycore@v1.8.5/app/dns/dns.go (about)

     1  // Package dns is an implementation of core.DNS feature.
     2  package dns
     3  
     4  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/moqsien/xraycore/app/router"
    13  	"github.com/moqsien/xraycore/common"
    14  	"github.com/moqsien/xraycore/common/errors"
    15  	"github.com/moqsien/xraycore/common/net"
    16  	"github.com/moqsien/xraycore/common/session"
    17  	"github.com/moqsien/xraycore/common/strmatcher"
    18  	"github.com/moqsien/xraycore/features"
    19  	"github.com/moqsien/xraycore/features/dns"
    20  )
    21  
    22  // DNS is a DNS rely server.
    23  type DNS struct {
    24  	sync.Mutex
    25  	tag                    string
    26  	disableCache           bool
    27  	disableFallback        bool
    28  	disableFallbackIfMatch bool
    29  	ipOption               *dns.IPOption
    30  	hosts                  *StaticHosts
    31  	clients                []*Client
    32  	ctx                    context.Context
    33  	domainMatcher          strmatcher.IndexMatcher
    34  	matcherInfos           []*DomainMatcherInfo
    35  }
    36  
    37  // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher
    38  type DomainMatcherInfo struct {
    39  	clientIdx     uint16
    40  	domainRuleIdx uint16
    41  }
    42  
    43  // New creates a new DNS server with given configuration.
    44  func New(ctx context.Context, config *Config) (*DNS, error) {
    45  	var tag string
    46  	if len(config.Tag) > 0 {
    47  		tag = config.Tag
    48  	} else {
    49  		tag = generateRandomTag()
    50  	}
    51  
    52  	var clientIP net.IP
    53  	switch len(config.ClientIp) {
    54  	case 0, net.IPv4len, net.IPv6len:
    55  		clientIP = net.IP(config.ClientIp)
    56  	default:
    57  		return nil, newError("unexpected client IP length ", len(config.ClientIp))
    58  	}
    59  
    60  	var ipOption *dns.IPOption
    61  	switch config.QueryStrategy {
    62  	case QueryStrategy_USE_IP:
    63  		ipOption = &dns.IPOption{
    64  			IPv4Enable: true,
    65  			IPv6Enable: true,
    66  			FakeEnable: false,
    67  		}
    68  	case QueryStrategy_USE_IP4:
    69  		ipOption = &dns.IPOption{
    70  			IPv4Enable: true,
    71  			IPv6Enable: false,
    72  			FakeEnable: false,
    73  		}
    74  	case QueryStrategy_USE_IP6:
    75  		ipOption = &dns.IPOption{
    76  			IPv4Enable: false,
    77  			IPv6Enable: true,
    78  			FakeEnable: false,
    79  		}
    80  	}
    81  
    82  	hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts)
    83  	if err != nil {
    84  		return nil, newError("failed to create hosts").Base(err)
    85  	}
    86  
    87  	clients := []*Client{}
    88  	domainRuleCount := 0
    89  	for _, ns := range config.NameServer {
    90  		domainRuleCount += len(ns.PrioritizedDomain)
    91  	}
    92  
    93  	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
    94  	matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1)
    95  	domainMatcher := &strmatcher.MatcherGroup{}
    96  	geoipContainer := router.GeoIPMatcherContainer{}
    97  
    98  	for _, endpoint := range config.NameServers {
    99  		features.PrintDeprecatedFeatureWarning("simple DNS server")
   100  		client, err := NewSimpleClient(ctx, endpoint, clientIP)
   101  		if err != nil {
   102  			return nil, newError("failed to create client").Base(err)
   103  		}
   104  		clients = append(clients, client)
   105  	}
   106  
   107  	for _, ns := range config.NameServer {
   108  		clientIdx := len(clients)
   109  		updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) error {
   110  			midx := domainMatcher.Add(domainRule)
   111  			matcherInfos[midx] = &DomainMatcherInfo{
   112  				clientIdx:     uint16(clientIdx),
   113  				domainRuleIdx: uint16(originalRuleIdx),
   114  			}
   115  			return nil
   116  		}
   117  
   118  		myClientIP := clientIP
   119  		switch len(ns.ClientIp) {
   120  		case net.IPv4len, net.IPv6len:
   121  			myClientIP = net.IP(ns.ClientIp)
   122  		}
   123  		client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
   124  		if err != nil {
   125  			return nil, newError("failed to create client").Base(err)
   126  		}
   127  		clients = append(clients, client)
   128  	}
   129  
   130  	// If there is no DNS client in config, add a `localhost` DNS client
   131  	if len(clients) == 0 {
   132  		clients = append(clients, NewLocalDNSClient())
   133  	}
   134  
   135  	return &DNS{
   136  		tag:                    tag,
   137  		hosts:                  hosts,
   138  		ipOption:               ipOption,
   139  		clients:                clients,
   140  		ctx:                    ctx,
   141  		domainMatcher:          domainMatcher,
   142  		matcherInfos:           matcherInfos,
   143  		disableCache:           config.DisableCache,
   144  		disableFallback:        config.DisableFallback,
   145  		disableFallbackIfMatch: config.DisableFallbackIfMatch,
   146  	}, nil
   147  }
   148  
   149  // Type implements common.HasType.
   150  func (*DNS) Type() interface{} {
   151  	return dns.ClientType()
   152  }
   153  
   154  // Start implements common.Runnable.
   155  func (s *DNS) Start() error {
   156  	return nil
   157  }
   158  
   159  // Close implements common.Closable.
   160  func (s *DNS) Close() error {
   161  	return nil
   162  }
   163  
   164  // IsOwnLink implements proxy.dns.ownLinkVerifier
   165  func (s *DNS) IsOwnLink(ctx context.Context) bool {
   166  	inbound := session.InboundFromContext(ctx)
   167  	return inbound != nil && inbound.Tag == s.tag
   168  }
   169  
   170  // LookupIP implements dns.Client.
   171  func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) {
   172  	if domain == "" {
   173  		return nil, newError("empty domain name")
   174  	}
   175  
   176  	option.IPv4Enable = option.IPv4Enable && s.ipOption.IPv4Enable
   177  	option.IPv6Enable = option.IPv6Enable && s.ipOption.IPv6Enable
   178  
   179  	if !option.IPv4Enable && !option.IPv6Enable {
   180  		return nil, dns.ErrEmptyResponse
   181  	}
   182  
   183  	// Normalize the FQDN form query
   184  	if strings.HasSuffix(domain, ".") {
   185  		domain = domain[:len(domain)-1]
   186  	}
   187  
   188  	// Static host lookup
   189  	switch addrs := s.hosts.Lookup(domain, option); {
   190  	case addrs == nil: // Domain not recorded in static host
   191  		break
   192  	case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled)
   193  		return nil, dns.ErrEmptyResponse
   194  	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement
   195  		newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog()
   196  		domain = addrs[0].Domain()
   197  	default: // Successfully found ip records in static host
   198  		newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog()
   199  		return toNetIP(addrs)
   200  	}
   201  
   202  	// Name servers lookup
   203  	errs := []error{}
   204  	ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag})
   205  	for _, client := range s.sortClients(domain) {
   206  		if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
   207  			newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog()
   208  			continue
   209  		}
   210  		ips, err := client.QueryIP(ctx, domain, option, s.disableCache)
   211  		if len(ips) > 0 {
   212  			return ips, nil
   213  		}
   214  		if err != nil {
   215  			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
   216  			errs = append(errs, err)
   217  		}
   218  		if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse {
   219  			return nil, err
   220  		}
   221  	}
   222  
   223  	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
   224  }
   225  
   226  // LookupHosts implements dns.HostsLookup.
   227  func (s *DNS) LookupHosts(domain string) *net.Address {
   228  	domain = strings.TrimSuffix(domain, ".")
   229  	if domain == "" {
   230  		return nil
   231  	}
   232  	// Normalize the FQDN form query
   233  	addrs := s.hosts.Lookup(domain, *s.ipOption)
   234  	if len(addrs) > 0 {
   235  		newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog()
   236  		return &addrs[0]
   237  	}
   238  
   239  	return nil
   240  }
   241  
   242  // GetIPOption implements ClientWithIPOption.
   243  func (s *DNS) GetIPOption() *dns.IPOption {
   244  	return s.ipOption
   245  }
   246  
   247  // SetQueryOption implements ClientWithIPOption.
   248  func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
   249  	s.ipOption.IPv4Enable = isIPv4Enable
   250  	s.ipOption.IPv6Enable = isIPv6Enable
   251  }
   252  
   253  // SetFakeDNSOption implements ClientWithIPOption.
   254  func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
   255  	s.ipOption.FakeEnable = isFakeEnable
   256  }
   257  
   258  func (s *DNS) sortClients(domain string) []*Client {
   259  	clients := make([]*Client, 0, len(s.clients))
   260  	clientUsed := make([]bool, len(s.clients))
   261  	clientNames := make([]string, 0, len(s.clients))
   262  	domainRules := []string{}
   263  
   264  	// Priority domain matching
   265  	hasMatch := false
   266  	for _, match := range s.domainMatcher.Match(domain) {
   267  		info := s.matcherInfos[match]
   268  		client := s.clients[info.clientIdx]
   269  		domainRule := client.domains[info.domainRuleIdx]
   270  		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
   271  		if clientUsed[info.clientIdx] {
   272  			continue
   273  		}
   274  		clientUsed[info.clientIdx] = true
   275  		clients = append(clients, client)
   276  		clientNames = append(clientNames, client.Name())
   277  		hasMatch = true
   278  	}
   279  
   280  	if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) {
   281  		// Default round-robin query
   282  		for idx, client := range s.clients {
   283  			if clientUsed[idx] || client.skipFallback {
   284  				continue
   285  			}
   286  			clientUsed[idx] = true
   287  			clients = append(clients, client)
   288  			clientNames = append(clientNames, client.Name())
   289  		}
   290  	}
   291  
   292  	if len(domainRules) > 0 {
   293  		newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
   294  	}
   295  	if len(clientNames) > 0 {
   296  		newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog()
   297  	}
   298  
   299  	if len(clients) == 0 {
   300  		clients = append(clients, s.clients[0])
   301  		clientNames = append(clientNames, s.clients[0].Name())
   302  		newError("domain ", domain, " will use the first DNS: ", clientNames).AtDebug().WriteToLog()
   303  	}
   304  
   305  	return clients
   306  }
   307  
   308  func init() {
   309  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   310  		return New(ctx, config.(*Config))
   311  	}))
   312  }