github.com/xraypb/xray-core@v1.6.6/app/dns/dns.go (about) 1 // Package dns is an implementation of core.DNS feature. 2 package dns 3 4 //go:generate go run github.com/xraypb/xray-core/common/errors/errorgen 5 6 import ( 7 "context" 8 "fmt" 9 "strings" 10 "sync" 11 12 "github.com/xraypb/xray-core/app/router" 13 "github.com/xraypb/xray-core/common" 14 "github.com/xraypb/xray-core/common/errors" 15 "github.com/xraypb/xray-core/common/net" 16 "github.com/xraypb/xray-core/common/session" 17 "github.com/xraypb/xray-core/common/strmatcher" 18 "github.com/xraypb/xray-core/features" 19 "github.com/xraypb/xray-core/features/dns" 20 ) 21 22 // DNS is a DNS rely server. 23 type DNS struct { 24 sync.Mutex 25 tag string 26 disableCache bool 27 disableFallback bool 28 disableFallbackIfMatch bool 29 ipOption *dns.IPOption 30 hosts *StaticHosts 31 clients []*Client 32 ctx context.Context 33 domainMatcher strmatcher.IndexMatcher 34 matcherInfos []*DomainMatcherInfo 35 } 36 37 // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher 38 type DomainMatcherInfo struct { 39 clientIdx uint16 40 domainRuleIdx uint16 41 } 42 43 // New creates a new DNS server with given configuration. 44 func New(ctx context.Context, config *Config) (*DNS, error) { 45 var tag string 46 if len(config.Tag) > 0 { 47 tag = config.Tag 48 } else { 49 tag = generateRandomTag() 50 } 51 52 var clientIP net.IP 53 switch len(config.ClientIp) { 54 case 0, net.IPv4len, net.IPv6len: 55 clientIP = net.IP(config.ClientIp) 56 default: 57 return nil, newError("unexpected client IP length ", len(config.ClientIp)) 58 } 59 60 var ipOption *dns.IPOption 61 switch config.QueryStrategy { 62 case QueryStrategy_USE_IP: 63 ipOption = &dns.IPOption{ 64 IPv4Enable: true, 65 IPv6Enable: true, 66 FakeEnable: false, 67 } 68 case QueryStrategy_USE_IP4: 69 ipOption = &dns.IPOption{ 70 IPv4Enable: true, 71 IPv6Enable: false, 72 FakeEnable: false, 73 } 74 case QueryStrategy_USE_IP6: 75 ipOption = &dns.IPOption{ 76 IPv4Enable: false, 77 IPv6Enable: true, 78 FakeEnable: false, 79 } 80 } 81 82 hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts) 83 if err != nil { 84 return nil, newError("failed to create hosts").Base(err) 85 } 86 87 clients := []*Client{} 88 domainRuleCount := 0 89 for _, ns := range config.NameServer { 90 domainRuleCount += len(ns.PrioritizedDomain) 91 } 92 93 // MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1 94 matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1) 95 domainMatcher := &strmatcher.MatcherGroup{} 96 geoipContainer := router.GeoIPMatcherContainer{} 97 98 for _, endpoint := range config.NameServers { 99 features.PrintDeprecatedFeatureWarning("simple DNS server") 100 client, err := NewSimpleClient(ctx, endpoint, clientIP) 101 if err != nil { 102 return nil, newError("failed to create client").Base(err) 103 } 104 clients = append(clients, client) 105 } 106 107 for _, ns := range config.NameServer { 108 clientIdx := len(clients) 109 updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) error { 110 midx := domainMatcher.Add(domainRule) 111 matcherInfos[midx] = &DomainMatcherInfo{ 112 clientIdx: uint16(clientIdx), 113 domainRuleIdx: uint16(originalRuleIdx), 114 } 115 return nil 116 } 117 118 myClientIP := clientIP 119 switch len(ns.ClientIp) { 120 case net.IPv4len, net.IPv6len: 121 myClientIP = net.IP(ns.ClientIp) 122 } 123 client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain) 124 if err != nil { 125 return nil, newError("failed to create client").Base(err) 126 } 127 clients = append(clients, client) 128 } 129 130 // If there is no DNS client in config, add a `localhost` DNS client 131 if len(clients) == 0 { 132 clients = append(clients, NewLocalDNSClient()) 133 } 134 135 return &DNS{ 136 tag: tag, 137 hosts: hosts, 138 ipOption: ipOption, 139 clients: clients, 140 ctx: ctx, 141 domainMatcher: domainMatcher, 142 matcherInfos: matcherInfos, 143 disableCache: config.DisableCache, 144 disableFallback: config.DisableFallback, 145 disableFallbackIfMatch: config.DisableFallbackIfMatch, 146 }, nil 147 } 148 149 // Type implements common.HasType. 150 func (*DNS) Type() interface{} { 151 return dns.ClientType() 152 } 153 154 // Start implements common.Runnable. 155 func (s *DNS) Start() error { 156 return nil 157 } 158 159 // Close implements common.Closable. 160 func (s *DNS) Close() error { 161 return nil 162 } 163 164 // IsOwnLink implements proxy.dns.ownLinkVerifier 165 func (s *DNS) IsOwnLink(ctx context.Context) bool { 166 inbound := session.InboundFromContext(ctx) 167 return inbound != nil && inbound.Tag == s.tag 168 } 169 170 // LookupIP implements dns.Client. 171 func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) { 172 if domain == "" { 173 return nil, newError("empty domain name") 174 } 175 176 option.IPv4Enable = option.IPv4Enable && s.ipOption.IPv4Enable 177 option.IPv6Enable = option.IPv6Enable && s.ipOption.IPv6Enable 178 179 if !option.IPv4Enable && !option.IPv6Enable { 180 return nil, dns.ErrEmptyResponse 181 } 182 183 // Normalize the FQDN form query 184 if strings.HasSuffix(domain, ".") { 185 domain = domain[:len(domain)-1] 186 } 187 188 // Static host lookup 189 switch addrs := s.hosts.Lookup(domain, option); { 190 case addrs == nil: // Domain not recorded in static host 191 break 192 case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled) 193 return nil, dns.ErrEmptyResponse 194 case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement 195 newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog() 196 domain = addrs[0].Domain() 197 default: // Successfully found ip records in static host 198 newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog() 199 return toNetIP(addrs) 200 } 201 202 // Name servers lookup 203 errs := []error{} 204 ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag}) 205 for _, client := range s.sortClients(domain) { 206 if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") { 207 newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog() 208 continue 209 } 210 ips, err := client.QueryIP(ctx, domain, option, s.disableCache) 211 if len(ips) > 0 { 212 return ips, nil 213 } 214 if err != nil { 215 newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog() 216 errs = append(errs, err) 217 } 218 if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch { 219 return nil, err 220 } 221 } 222 223 return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...)) 224 } 225 226 // LookupHosts implements dns.HostsLookup. 227 func (s *DNS) LookupHosts(domain string) *net.Address { 228 domain = strings.TrimSuffix(domain, ".") 229 if domain == "" { 230 return nil 231 } 232 // Normalize the FQDN form query 233 addrs := s.hosts.Lookup(domain, *s.ipOption) 234 if len(addrs) > 0 { 235 newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog() 236 return &addrs[0] 237 } 238 239 return nil 240 } 241 242 // GetIPOption implements ClientWithIPOption. 243 func (s *DNS) GetIPOption() *dns.IPOption { 244 return s.ipOption 245 } 246 247 // SetQueryOption implements ClientWithIPOption. 248 func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) { 249 s.ipOption.IPv4Enable = isIPv4Enable 250 s.ipOption.IPv6Enable = isIPv6Enable 251 } 252 253 // SetFakeDNSOption implements ClientWithIPOption. 254 func (s *DNS) SetFakeDNSOption(isFakeEnable bool) { 255 s.ipOption.FakeEnable = isFakeEnable 256 } 257 258 func (s *DNS) sortClients(domain string) []*Client { 259 clients := make([]*Client, 0, len(s.clients)) 260 clientUsed := make([]bool, len(s.clients)) 261 clientNames := make([]string, 0, len(s.clients)) 262 domainRules := []string{} 263 264 // Priority domain matching 265 hasMatch := false 266 for _, match := range s.domainMatcher.Match(domain) { 267 info := s.matcherInfos[match] 268 client := s.clients[info.clientIdx] 269 domainRule := client.domains[info.domainRuleIdx] 270 domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx)) 271 if clientUsed[info.clientIdx] { 272 continue 273 } 274 clientUsed[info.clientIdx] = true 275 clients = append(clients, client) 276 clientNames = append(clientNames, client.Name()) 277 hasMatch = true 278 } 279 280 if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) { 281 // Default round-robin query 282 for idx, client := range s.clients { 283 if clientUsed[idx] || client.skipFallback { 284 continue 285 } 286 clientUsed[idx] = true 287 clients = append(clients, client) 288 clientNames = append(clientNames, client.Name()) 289 } 290 } 291 292 if len(domainRules) > 0 { 293 newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog() 294 } 295 if len(clientNames) > 0 { 296 newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog() 297 } 298 299 if len(clients) == 0 { 300 clients = append(clients, s.clients[0]) 301 clientNames = append(clientNames, s.clients[0].Name()) 302 newError("domain ", domain, " will use the first DNS: ", clientNames).AtDebug().WriteToLog() 303 } 304 305 return clients 306 } 307 308 func init() { 309 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 310 return New(ctx, config.(*Config)) 311 })) 312 }