github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/dns/hosts.go (about)

     1  package dns
     2  
     3  import (
     4  	"github.com/xtls/xray-core/common"
     5  	"github.com/xtls/xray-core/common/net"
     6  	"github.com/xtls/xray-core/common/strmatcher"
     7  	"github.com/xtls/xray-core/features"
     8  	"github.com/xtls/xray-core/features/dns"
     9  )
    10  
    11  // StaticHosts represents static domain-ip mapping in DNS server.
    12  type StaticHosts struct {
    13  	ips      [][]net.Address
    14  	matchers *strmatcher.MatcherGroup
    15  }
    16  
    17  // NewStaticHosts creates a new StaticHosts instance.
    18  func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDomain) (*StaticHosts, error) {
    19  	g := new(strmatcher.MatcherGroup)
    20  	sh := &StaticHosts{
    21  		ips:      make([][]net.Address, len(hosts)+len(legacy)+16),
    22  		matchers: g,
    23  	}
    24  
    25  	if legacy != nil {
    26  		features.PrintDeprecatedFeatureWarning("simple host mapping")
    27  
    28  		for domain, ip := range legacy {
    29  			matcher, err := strmatcher.Full.New(domain)
    30  			common.Must(err)
    31  			id := g.Add(matcher)
    32  
    33  			address := ip.AsAddress()
    34  			if address.Family().IsDomain() {
    35  				return nil, newError("invalid domain address in static hosts: ", address.Domain()).AtWarning()
    36  			}
    37  
    38  			sh.ips[id] = []net.Address{address}
    39  		}
    40  	}
    41  
    42  	for _, mapping := range hosts {
    43  		matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
    44  		if err != nil {
    45  			return nil, newError("failed to create domain matcher").Base(err)
    46  		}
    47  		id := g.Add(matcher)
    48  		ips := make([]net.Address, 0, len(mapping.Ip)+1)
    49  		switch {
    50  		case len(mapping.ProxiedDomain) > 0:
    51  			ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
    52  		case len(mapping.Ip) > 0:
    53  			for _, ip := range mapping.Ip {
    54  				addr := net.IPAddress(ip)
    55  				if addr == nil {
    56  					return nil, newError("invalid IP address in static hosts: ", ip).AtWarning()
    57  				}
    58  				ips = append(ips, addr)
    59  			}
    60  		default:
    61  			return nil, newError("neither IP address nor proxied domain specified for domain: ", mapping.Domain).AtWarning()
    62  		}
    63  
    64  		sh.ips[id] = ips
    65  	}
    66  
    67  	return sh, nil
    68  }
    69  
    70  func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
    71  	filtered := make([]net.Address, 0, len(ips))
    72  	for _, ip := range ips {
    73  		if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
    74  			filtered = append(filtered, ip)
    75  		}
    76  	}
    77  	return filtered
    78  }
    79  
    80  func (h *StaticHosts) lookupInternal(domain string) []net.Address {
    81  	var ips []net.Address
    82  	for _, id := range h.matchers.Match(domain) {
    83  		ips = append(ips, h.ips[id]...)
    84  	}
    85  	return ips
    86  }
    87  
    88  func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address {
    89  	switch addrs := h.lookupInternal(domain); {
    90  	case len(addrs) == 0: // Not recorded in static hosts, return nil
    91  		return nil
    92  	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain
    93  		newError("found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it").AtDebug().WriteToLog()
    94  		if maxDepth > 0 {
    95  			unwrapped := h.lookup(addrs[0].Domain(), option, maxDepth-1)
    96  			if unwrapped != nil {
    97  				return unwrapped
    98  			}
    99  		}
   100  		return addrs
   101  	default: // IP record found, return a non-nil IP array
   102  		return filterIP(addrs, option)
   103  	}
   104  }
   105  
   106  // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
   107  func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address {
   108  	return h.lookup(domain, option, 5)
   109  }