github.com/imannamdari/v2ray-core/v5@v5.0.5/app/dns/dns.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  // Package dns is an implementation of core.DNS feature.
     5  package dns
     6  
     7  //go:generate go run github.com/imannamdari/v2ray-core/v5/common/errors/errorgen
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"strings"
    13  	"sync"
    14  
    15  	core "github.com/imannamdari/v2ray-core/v5"
    16  	"github.com/imannamdari/v2ray-core/v5/app/dns/fakedns"
    17  	"github.com/imannamdari/v2ray-core/v5/app/router"
    18  	"github.com/imannamdari/v2ray-core/v5/common"
    19  	"github.com/imannamdari/v2ray-core/v5/common/errors"
    20  	"github.com/imannamdari/v2ray-core/v5/common/net"
    21  	"github.com/imannamdari/v2ray-core/v5/common/platform"
    22  	"github.com/imannamdari/v2ray-core/v5/common/session"
    23  	"github.com/imannamdari/v2ray-core/v5/common/strmatcher"
    24  	"github.com/imannamdari/v2ray-core/v5/features"
    25  	"github.com/imannamdari/v2ray-core/v5/features/dns"
    26  	"github.com/imannamdari/v2ray-core/v5/infra/conf/cfgcommon"
    27  	"github.com/imannamdari/v2ray-core/v5/infra/conf/geodata"
    28  )
    29  
// DNS is a DNS relay server. Lookups consult the static hosts first and then
// the configured name server clients in priority order.
type DNS struct {
	// Mutex is embedded; no method in this file locks it.
	sync.Mutex
	// hosts holds the static host mappings consulted before any client.
	hosts         *StaticHosts
	// clients lists legacy NameServers entries first, then NameServer
	// entries in config order, or a single local client if none are set.
	clients       []*Client
	// ctx is the context New was called with; it is passed to client
	// queries and to core.RequireFeatures.
	ctx           context.Context
	// clientTags is the set of client tags checked by IsOwnLink.
	clientTags    map[string]bool
	// fakeDNSEngine is set by establishFakeDNS only when fake DNS holders
	// were configured; otherwise it stays nil.
	fakeDNSEngine *FakeDNSEngine
	// domainMatcher matches a domain against every prioritized domain rule.
	domainMatcher strmatcher.IndexMatcher
	// matcherInfos maps an index returned by domainMatcher to its client
	// and rule; index 0 is unused because matcher indices start at 1.
	matcherInfos  []DomainMatcherInfo
}
    41  
// DomainMatcherInfo contains information attached to an index returned by
// DNS.domainMatcher.
type DomainMatcherInfo struct {
	clientIdx     uint16 // index into DNS.clients of the client owning the rule
	domainRuleIdx uint16 // index into that client's domains (rule strings used for logging)
}
    47  
    48  // New creates a new DNS server with given configuration.
    49  func New(ctx context.Context, config *Config) (*DNS, error) {
    50  	// Create static hosts
    51  	hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts)
    52  	if err != nil {
    53  		return nil, newError("failed to create hosts").Base(err)
    54  	}
    55  
    56  	// Create name servers from legacy configs
    57  	clients := []*Client{}
    58  	for _, endpoint := range config.NameServers {
    59  		features.PrintDeprecatedFeatureWarning("simple DNS server")
    60  		client, err := NewClient(ctx, &NameServer{Address: endpoint}, config)
    61  		if err != nil {
    62  			return nil, newError("failed to create client").Base(err)
    63  		}
    64  		clients = append(clients, client)
    65  	}
    66  
    67  	// Create name servers
    68  	nsClientMap := map[int]int{}
    69  	for nsIdx, ns := range config.NameServer {
    70  		client, err := NewClient(ctx, ns, config)
    71  		if err != nil {
    72  			return nil, newError("failed to create client").Base(err)
    73  		}
    74  		nsClientMap[nsIdx] = len(clients)
    75  		clients = append(clients, client)
    76  	}
    77  
    78  	// If there is no DNS client in config, add a `localhost` DNS client
    79  	if len(clients) == 0 {
    80  		clients = append(clients, NewLocalDNSClient())
    81  	}
    82  
    83  	s := &DNS{
    84  		hosts:   hosts,
    85  		clients: clients,
    86  		ctx:     ctx,
    87  	}
    88  
    89  	// Establish members related to global DNS state
    90  	s.clientTags = make(map[string]bool)
    91  	for _, client := range clients {
    92  		s.clientTags[client.tag] = true
    93  	}
    94  	if err := establishDomainRules(s, config, nsClientMap); err != nil {
    95  		return nil, err
    96  	}
    97  	if err := establishExpectedIPs(s, config, nsClientMap); err != nil {
    98  		return nil, err
    99  	}
   100  	if err := establishFakeDNS(s, config, nsClientMap); err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	return s, nil
   105  }
   106  
   107  func establishDomainRules(s *DNS, config *Config, nsClientMap map[int]int) error {
   108  	domainRuleCount := 0
   109  	for _, ns := range config.NameServer {
   110  		domainRuleCount += len(ns.PrioritizedDomain)
   111  	}
   112  	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
   113  	matcherInfos := make([]DomainMatcherInfo, domainRuleCount+1)
   114  	var domainMatcher strmatcher.IndexMatcher
   115  	switch config.DomainMatcher {
   116  	case "mph", "hybrid":
   117  		newError("using mph domain matcher").AtDebug().WriteToLog()
   118  		domainMatcher = strmatcher.NewMphIndexMatcher()
   119  	case "linear":
   120  		fallthrough
   121  	default:
   122  		newError("using default domain matcher").AtDebug().WriteToLog()
   123  		domainMatcher = strmatcher.NewLinearIndexMatcher()
   124  	}
   125  	for nsIdx, ns := range config.NameServer {
   126  		clientIdx := nsClientMap[nsIdx]
   127  		var rules []string
   128  		ruleCurr := 0
   129  		ruleIter := 0
   130  		for _, domain := range ns.PrioritizedDomain {
   131  			domainRule, err := toStrMatcher(domain.Type, domain.Domain)
   132  			if err != nil {
   133  				return newError("failed to create prioritized domain").Base(err).AtWarning()
   134  			}
   135  			originalRuleIdx := ruleCurr
   136  			if ruleCurr < len(ns.OriginalRules) {
   137  				rule := ns.OriginalRules[ruleCurr]
   138  				if ruleCurr >= len(rules) {
   139  					rules = append(rules, rule.Rule)
   140  				}
   141  				ruleIter++
   142  				if ruleIter >= int(rule.Size) {
   143  					ruleIter = 0
   144  					ruleCurr++
   145  				}
   146  			} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
   147  				rules = append(rules, domainRule.String())
   148  				ruleCurr++
   149  			}
   150  			midx := domainMatcher.Add(domainRule)
   151  			matcherInfos[midx] = DomainMatcherInfo{
   152  				clientIdx:     uint16(clientIdx),
   153  				domainRuleIdx: uint16(originalRuleIdx),
   154  			}
   155  			if err != nil {
   156  				return newError("failed to create prioritized domain").Base(err).AtWarning()
   157  			}
   158  		}
   159  		s.clients[clientIdx].domains = rules
   160  	}
   161  	if err := domainMatcher.Build(); err != nil {
   162  		return err
   163  	}
   164  	s.domainMatcher = domainMatcher
   165  	s.matcherInfos = matcherInfos
   166  	return nil
   167  }
   168  
   169  func establishExpectedIPs(s *DNS, config *Config, nsClientMap map[int]int) error {
   170  	geoipContainer := router.GeoIPMatcherContainer{}
   171  	for nsIdx, ns := range config.NameServer {
   172  		clientIdx := nsClientMap[nsIdx]
   173  		var matchers []*router.GeoIPMatcher
   174  		for _, geoip := range ns.Geoip {
   175  			matcher, err := geoipContainer.Add(geoip)
   176  			if err != nil {
   177  				return newError("failed to create ip matcher").Base(err).AtWarning()
   178  			}
   179  			matchers = append(matchers, matcher)
   180  		}
   181  		s.clients[clientIdx].expectIPs = matchers
   182  	}
   183  	return nil
   184  }
   185  
   186  func establishFakeDNS(s *DNS, config *Config, nsClientMap map[int]int) error {
   187  	fakeHolders := &fakedns.HolderMulti{}
   188  	fakeDefault := (*fakedns.HolderMulti)(nil)
   189  	if config.FakeDns != nil {
   190  		defaultEngine, err := fakeHolders.AddPoolMulti(config.FakeDns)
   191  		if err != nil {
   192  			return newError("fail to create fake dns").Base(err).AtWarning()
   193  		}
   194  		fakeDefault = defaultEngine
   195  	}
   196  	for nsIdx, ns := range config.NameServer {
   197  		clientIdx := nsClientMap[nsIdx]
   198  		if ns.FakeDns == nil {
   199  			continue
   200  		}
   201  		engine, err := fakeHolders.AddPoolMulti(ns.FakeDns)
   202  		if err != nil {
   203  			return newError("fail to create fake dns").Base(err).AtWarning()
   204  		}
   205  		s.clients[clientIdx].fakeDNS = NewFakeDNSServer(engine)
   206  		s.clients[clientIdx].queryStrategy.FakeEnable = true
   207  	}
   208  	// Do not create FakeDNSEngine feature if no FakeDNS server is configured
   209  	if fakeHolders.IsEmpty() {
   210  		return nil
   211  	}
   212  	// Add FakeDNSEngine feature when DNS feature is added for the first time
   213  	s.fakeDNSEngine = &FakeDNSEngine{dns: s, fakeHolders: fakeHolders, fakeDefault: fakeDefault}
   214  	return core.RequireFeatures(s.ctx, func(client dns.Client) error {
   215  		v := core.MustFromContext(s.ctx)
   216  		if v.GetFeature(dns.FakeDNSEngineType()) != nil {
   217  			return nil
   218  		}
   219  		if client, ok := client.(dns.ClientWithFakeDNS); ok {
   220  			return v.AddFeature(client.AsFakeDNSEngine())
   221  		}
   222  		return nil
   223  	})
   224  }
   225  
   226  // Type implements common.HasType.
   227  func (*DNS) Type() interface{} {
   228  	return dns.ClientType()
   229  }
   230  
   231  // Start implements common.Runnable.
   232  func (s *DNS) Start() error {
   233  	return nil
   234  }
   235  
   236  // Close implements common.Closable.
   237  func (s *DNS) Close() error {
   238  	return nil
   239  }
   240  
   241  // IsOwnLink implements proxy.dns.ownLinkVerifier
   242  func (s *DNS) IsOwnLink(ctx context.Context) bool {
   243  	inbound := session.InboundFromContext(ctx)
   244  	return inbound != nil && s.clientTags[inbound.Tag]
   245  }
   246  
   247  // AsFakeDNSClient implements dns.ClientWithFakeDNS.
   248  func (s *DNS) AsFakeDNSClient() dns.Client {
   249  	return &FakeDNSClient{DNS: s}
   250  }
   251  
   252  // AsFakeDNSEngine implements dns.ClientWithFakeDNS.
   253  func (s *DNS) AsFakeDNSEngine() dns.FakeDNSEngine {
   254  	return s.fakeDNSEngine
   255  }
   256  
   257  // LookupIP implements dns.Client.
   258  func (s *DNS) LookupIP(domain string) ([]net.IP, error) {
   259  	return s.lookupIPInternal(domain, dns.IPOption{IPv4Enable: true, IPv6Enable: true, FakeEnable: false})
   260  }
   261  
   262  // LookupIPv4 implements dns.IPv4Lookup.
   263  func (s *DNS) LookupIPv4(domain string) ([]net.IP, error) {
   264  	return s.lookupIPInternal(domain, dns.IPOption{IPv4Enable: true, FakeEnable: false})
   265  }
   266  
   267  // LookupIPv6 implements dns.IPv6Lookup.
   268  func (s *DNS) LookupIPv6(domain string) ([]net.IP, error) {
   269  	return s.lookupIPInternal(domain, dns.IPOption{IPv6Enable: true, FakeEnable: false})
   270  }
   271  
   272  func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, error) {
   273  	if domain == "" {
   274  		return nil, newError("empty domain name")
   275  	}
   276  
   277  	// Normalize the FQDN form query
   278  	domain = strings.TrimSuffix(domain, ".")
   279  
   280  	// Static host lookup
   281  	switch addrs := s.hosts.Lookup(domain, option); {
   282  	case addrs == nil: // Domain not recorded in static host
   283  		break
   284  	case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled)
   285  		return nil, dns.ErrEmptyResponse
   286  	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement
   287  		newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog()
   288  		domain = addrs[0].Domain()
   289  	default: // Successfully found ip records in static host
   290  		newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog()
   291  		return toNetIP(addrs)
   292  	}
   293  
   294  	// Name servers lookup
   295  	errs := []error{}
   296  	for _, client := range s.sortClients(domain, option) {
   297  		ips, err := client.QueryIP(s.ctx, domain, option)
   298  		if len(ips) > 0 {
   299  			return ips, nil
   300  		}
   301  		if err != nil {
   302  			errs = append(errs, err)
   303  		}
   304  		if err != dns.ErrEmptyResponse { // ErrEmptyResponse is not seen as failure, so no failed log
   305  			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
   306  		}
   307  		if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch {
   308  			return nil, err // Only continue lookup for certain errors
   309  		}
   310  	}
   311  
   312  	if len(errs) == 0 {
   313  		return nil, dns.ErrEmptyResponse
   314  	}
   315  	return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
   316  }
   317  
   318  func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
   319  	clients := make([]*Client, 0, len(s.clients))
   320  	clientUsed := make([]bool, len(s.clients))
   321  	clientIdxs := make([]int, 0, len(s.clients))
   322  	domainRules := []string{}
   323  
   324  	// Priority domain matching
   325  	for _, match := range s.domainMatcher.Match(domain) {
   326  		info := s.matcherInfos[match]
   327  		client := s.clients[info.clientIdx]
   328  		domainRule := client.domains[info.domainRuleIdx]
   329  		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
   330  		switch {
   331  		case clientUsed[info.clientIdx]:
   332  			continue
   333  		case !option.FakeEnable && isFakeDNS(client.server):
   334  			continue
   335  		}
   336  		clientUsed[info.clientIdx] = true
   337  		clients = append(clients, client)
   338  		clientIdxs = append(clientIdxs, int(info.clientIdx))
   339  	}
   340  
   341  	// Default round-robin query
   342  	hasDomainMatch := len(clients) > 0
   343  	for idx, client := range s.clients {
   344  		switch {
   345  		case clientUsed[idx]:
   346  			continue
   347  		case !option.FakeEnable && isFakeDNS(client.server):
   348  			continue
   349  		case client.fallbackStrategy == FallbackStrategy_Disabled:
   350  			continue
   351  		case client.fallbackStrategy == FallbackStrategy_DisabledIfAnyMatch && hasDomainMatch:
   352  			continue
   353  		}
   354  		clientUsed[idx] = true
   355  		clients = append(clients, client)
   356  		clientIdxs = append(clientIdxs, idx)
   357  	}
   358  
   359  	if len(domainRules) > 0 {
   360  		newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
   361  	}
   362  	if len(clientIdxs) > 0 {
   363  		newError("domain ", domain, " will use DNS in order: ", s.formatClientNames(clientIdxs, option), " ", toReqTypes(option)).AtDebug().WriteToLog()
   364  	}
   365  
   366  	return clients
   367  }
   368  
   369  func (s *DNS) formatClientNames(clientIdxs []int, option dns.IPOption) []string {
   370  	clientNames := make([]string, 0, len(clientIdxs))
   371  	counter := make(map[string]uint, len(clientIdxs))
   372  	for _, clientIdx := range clientIdxs {
   373  		client := s.clients[clientIdx]
   374  		var name string
   375  		if option.With(client.queryStrategy).FakeEnable {
   376  			name = fmt.Sprintf("%s(DNS idx:%d)", client.fakeDNS.Name(), clientIdx)
   377  		} else {
   378  			name = client.Name()
   379  		}
   380  		counter[name]++
   381  		clientNames = append(clientNames, name)
   382  	}
   383  	for idx, clientIdx := range clientIdxs {
   384  		name := clientNames[idx]
   385  		if counter[name] > 1 {
   386  			clientNames[idx] = fmt.Sprintf("%s(DNS idx:%d)", name, clientIdx)
   387  		}
   388  	}
   389  	return clientNames
   390  }
   391  
   392  func init() {
   393  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   394  		return New(ctx, config.(*Config))
   395  	}))
   396  
   397  	common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   398  		ctx = cfgcommon.NewConfigureLoadingContext(ctx)
   399  
   400  		geoloadername := platform.NewEnvFlag("v2ray.conf.geoloader").GetValue(func() string {
   401  			return "standard"
   402  		})
   403  
   404  		if loader, err := geodata.GetGeoDataLoader(geoloadername); err == nil {
   405  			cfgcommon.SetGeoDataLoader(ctx, loader)
   406  		} else {
   407  			return nil, newError("unable to create geo data loader ").Base(err)
   408  		}
   409  
   410  		cfgEnv := cfgcommon.GetConfigureLoadingEnvironment(ctx)
   411  		geoLoader := cfgEnv.GetGeoLoader()
   412  
   413  		simplifiedConfig := config.(*SimplifiedConfig)
   414  		for _, v := range simplifiedConfig.NameServer {
   415  			for _, geo := range v.Geoip {
   416  				if geo.Code != "" {
   417  					filepath := "geoip.dat"
   418  					if geo.FilePath != "" {
   419  						filepath = geo.FilePath
   420  					} else {
   421  						geo.CountryCode = geo.Code
   422  					}
   423  					var err error
   424  					geo.Cidr, err = geoLoader.LoadIP(filepath, geo.Code)
   425  					if err != nil {
   426  						return nil, newError("unable to load geoip").Base(err)
   427  					}
   428  				}
   429  			}
   430  		}
   431  
   432  		var nameservers []*NameServer
   433  
   434  		for _, v := range simplifiedConfig.NameServer {
   435  			nameserver := &NameServer{
   436  				Address:          v.Address,
   437  				ClientIp:         net.ParseIP(v.ClientIp),
   438  				Tag:              v.Tag,
   439  				QueryStrategy:    v.QueryStrategy,
   440  				CacheStrategy:    v.CacheStrategy,
   441  				FallbackStrategy: v.FallbackStrategy,
   442  				SkipFallback:     v.SkipFallback,
   443  				Geoip:            v.Geoip,
   444  			}
   445  			for _, prioritizedDomain := range v.PrioritizedDomain {
   446  				nameserver.PrioritizedDomain = append(nameserver.PrioritizedDomain, &NameServer_PriorityDomain{
   447  					Type:   prioritizedDomain.Type,
   448  					Domain: prioritizedDomain.Domain,
   449  				})
   450  			}
   451  			nameservers = append(nameservers, nameserver)
   452  		}
   453  
   454  		var statichosts []*HostMapping
   455  
   456  		for _, v := range simplifiedConfig.StaticHosts {
   457  			statichost := &HostMapping{
   458  				Type:          v.Type,
   459  				Domain:        v.Domain,
   460  				ProxiedDomain: v.ProxiedDomain,
   461  			}
   462  			for _, ip := range v.Ip {
   463  				statichost.Ip = append(statichost.Ip, net.ParseIP(ip))
   464  			}
   465  			statichosts = append(statichosts, statichost)
   466  		}
   467  
   468  		fullConfig := &Config{
   469  			StaticHosts:      statichosts,
   470  			NameServer:       nameservers,
   471  			ClientIp:         net.ParseIP(simplifiedConfig.ClientIp),
   472  			Tag:              simplifiedConfig.Tag,
   473  			QueryStrategy:    simplifiedConfig.QueryStrategy,
   474  			CacheStrategy:    simplifiedConfig.CacheStrategy,
   475  			FallbackStrategy: simplifiedConfig.FallbackStrategy,
   476  			// Deprecated flags
   477  			DisableCache:           simplifiedConfig.DisableCache,
   478  			DisableFallback:        simplifiedConfig.DisableFallback,
   479  			DisableFallbackIfMatch: simplifiedConfig.DisableFallbackIfMatch,
   480  		}
   481  		return common.CreateObject(ctx, fullConfig)
   482  	}))
   483  }