github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/dns.go (about) 1 // Package dns is an implementation of core.DNS feature. 2 package dns 3 4 //go:generate go run github.com/xmplusdev/xmcore/common/errors/errorgen 5 6 import ( 7 "context" 8 "fmt" 9 "strings" 10 "sync" 11 12 "github.com/xmplusdev/xmcore/app/router" 13 "github.com/xmplusdev/xmcore/common" 14 "github.com/xmplusdev/xmcore/common/errors" 15 "github.com/xmplusdev/xmcore/common/net" 16 "github.com/xmplusdev/xmcore/common/session" 17 "github.com/xmplusdev/xmcore/common/strmatcher" 18 "github.com/xmplusdev/xmcore/features" 19 "github.com/xmplusdev/xmcore/features/dns" 20 ) 21 22 // DNS is a DNS rely server. 23 type DNS struct { 24 sync.Mutex 25 tag string 26 disableCache bool 27 disableFallback bool 28 disableFallbackIfMatch bool 29 ipOption *dns.IPOption 30 hosts *StaticHosts 31 clients []*Client 32 ctx context.Context 33 domainMatcher strmatcher.IndexMatcher 34 matcherInfos []*DomainMatcherInfo 35 } 36 37 // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher 38 type DomainMatcherInfo struct { 39 clientIdx uint16 40 domainRuleIdx uint16 41 } 42 43 // New creates a new DNS server with given configuration. 44 func New(ctx context.Context, config *Config) (*DNS, error) { 45 var tag string 46 if len(config.Tag) > 0 { 47 tag = config.Tag 48 } else { 49 tag = generateRandomTag() 50 } 51 52 var clientIP net.IP 53 switch len(config.ClientIp) { 54 case 0, net.IPv4len, net.IPv6len: 55 clientIP = net.IP(config.ClientIp) 56 default: 57 return nil, newError("unexpected client IP length ", len(config.ClientIp)) 58 } 59 60 var ipOption *dns.IPOption 61 switch config.QueryStrategy { 62 case QueryStrategy_USE_IP: 63 ipOption = &dns.IPOption{ 64 IPv4Enable: true, 65 IPv6Enable: true, 66 FakeEnable: false, 67 } 68 case QueryStrategy_USE_IP4: 69 ipOption = &dns.IPOption{ 70 IPv4Enable: true, 71 IPv6Enable: false, 72 FakeEnable: false, 73 } 74 case QueryStrategy_USE_IP6: 75 ipOption = &dns.IPOption{ 76 IPv4Enable: false, 77 IPv6Enable: true, 78 FakeEnable: false, 79 } 80 } 81 82 hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts) 83 if err != nil { 84 return nil, newError("failed to create hosts").Base(err) 85 } 86 87 clients := []*Client{} 88 domainRuleCount := 0 89 for _, ns := range config.NameServer { 90 domainRuleCount += len(ns.PrioritizedDomain) 91 } 92 93 // MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1 94 matcherInfos := make([]*DomainMatcherInfo, domainRuleCount+1) 95 domainMatcher := &strmatcher.MatcherGroup{} 96 geoipContainer := router.GeoIPMatcherContainer{} 97 98 for _, endpoint := range config.NameServers { 99 features.PrintDeprecatedFeatureWarning("simple DNS server") 100 client, err := NewSimpleClient(ctx, endpoint, clientIP) 101 if err != nil { 102 return nil, newError("failed to create client").Base(err) 103 } 104 clients = append(clients, client) 105 } 106 107 for _, ns := range config.NameServer { 108 clientIdx := len(clients) 109 updateDomain := func(domainRule strmatcher.Matcher, originalRuleIdx int, matcherInfos []*DomainMatcherInfo) error { 110 midx := domainMatcher.Add(domainRule) 111 matcherInfos[midx] = &DomainMatcherInfo{ 112 clientIdx: uint16(clientIdx), 113 domainRuleIdx: uint16(originalRuleIdx), 114 } 115 return nil 116 } 117 118 myClientIP := clientIP 119 switch len(ns.ClientIp) { 120 case net.IPv4len, net.IPv6len: 121 myClientIP = net.IP(ns.ClientIp) 122 } 123 client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain) 124 if err != nil { 125 return nil, newError("failed to create client").Base(err) 126 } 127 clients = append(clients, client) 128 } 129 130 // If there is no DNS client in config, add a `localhost` DNS client 131 if len(clients) == 0 { 132 clients = append(clients, NewLocalDNSClient()) 133 } 134 135 return &DNS{ 136 tag: tag, 137 hosts: hosts, 138 ipOption: ipOption, 139 clients: clients, 140 ctx: ctx, 141 domainMatcher: domainMatcher, 142 matcherInfos: matcherInfos, 143 disableCache: config.DisableCache, 144 disableFallback: config.DisableFallback, 145 disableFallbackIfMatch: config.DisableFallbackIfMatch, 146 }, nil 147 } 148 149 // Type implements common.HasType. 150 func (*DNS) Type() interface{} { 151 return dns.ClientType() 152 } 153 154 // Start implements common.Runnable. 155 func (s *DNS) Start() error { 156 return nil 157 } 158 159 // Close implements common.Closable. 160 func (s *DNS) Close() error { 161 return nil 162 } 163 164 // IsOwnLink implements proxy.dns.ownLinkVerifier 165 func (s *DNS) IsOwnLink(ctx context.Context) bool { 166 inbound := session.InboundFromContext(ctx) 167 return inbound != nil && inbound.Tag == s.tag 168 } 169 170 // LookupIP implements dns.Client. 171 func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, error) { 172 if domain == "" { 173 return nil, newError("empty domain name") 174 } 175 176 option.IPv4Enable = option.IPv4Enable && s.ipOption.IPv4Enable 177 option.IPv6Enable = option.IPv6Enable && s.ipOption.IPv6Enable 178 179 if !option.IPv4Enable && !option.IPv6Enable { 180 return nil, dns.ErrEmptyResponse 181 } 182 183 // Normalize the FQDN form query 184 domain = strings.TrimSuffix(domain, ".") 185 186 // Static host lookup 187 switch addrs := s.hosts.Lookup(domain, option); { 188 case addrs == nil: // Domain not recorded in static host 189 break 190 case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled) 191 return nil, dns.ErrEmptyResponse 192 case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Domain replacement 193 newError("domain replaced: ", domain, " -> ", addrs[0].Domain()).WriteToLog() 194 domain = addrs[0].Domain() 195 default: // Successfully found ip records in static host 196 newError("returning ", len(addrs), " IP(s) for domain ", domain, " -> ", addrs).WriteToLog() 197 return toNetIP(addrs) 198 } 199 200 // Name servers lookup 201 errs := []error{} 202 ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag}) 203 for _, client := range s.sortClients(domain) { 204 if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") { 205 newError("skip DNS resolution for domain ", domain, " at server ", client.Name()).AtDebug().WriteToLog() 206 continue 207 } 208 ips, err := client.QueryIP(ctx, domain, option, s.disableCache) 209 if len(ips) > 0 { 210 return ips, nil 211 } 212 if err != nil { 213 newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog() 214 errs = append(errs, err) 215 } 216 // 5 for RcodeRefused in miekg/dns, hardcode to reduce binary size 217 if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch && err != dns.ErrEmptyResponse && dns.RCodeFromError(err) != 5 { 218 return nil, err 219 } 220 } 221 222 return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...)) 223 } 224 225 // LookupHosts implements dns.HostsLookup. 226 func (s *DNS) LookupHosts(domain string) *net.Address { 227 domain = strings.TrimSuffix(domain, ".") 228 if domain == "" { 229 return nil 230 } 231 // Normalize the FQDN form query 232 addrs := s.hosts.Lookup(domain, *s.ipOption) 233 if len(addrs) > 0 { 234 newError("domain replaced: ", domain, " -> ", addrs[0].String()).AtInfo().WriteToLog() 235 return &addrs[0] 236 } 237 238 return nil 239 } 240 241 // GetIPOption implements ClientWithIPOption. 242 func (s *DNS) GetIPOption() *dns.IPOption { 243 return s.ipOption 244 } 245 246 // SetQueryOption implements ClientWithIPOption. 247 func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) { 248 s.ipOption.IPv4Enable = isIPv4Enable 249 s.ipOption.IPv6Enable = isIPv6Enable 250 } 251 252 // SetFakeDNSOption implements ClientWithIPOption. 253 func (s *DNS) SetFakeDNSOption(isFakeEnable bool) { 254 s.ipOption.FakeEnable = isFakeEnable 255 } 256 257 func (s *DNS) sortClients(domain string) []*Client { 258 clients := make([]*Client, 0, len(s.clients)) 259 clientUsed := make([]bool, len(s.clients)) 260 clientNames := make([]string, 0, len(s.clients)) 261 domainRules := []string{} 262 263 // Priority domain matching 264 hasMatch := false 265 for _, match := range s.domainMatcher.Match(domain) { 266 info := s.matcherInfos[match] 267 client := s.clients[info.clientIdx] 268 domainRule := client.domains[info.domainRuleIdx] 269 domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx)) 270 if clientUsed[info.clientIdx] { 271 continue 272 } 273 clientUsed[info.clientIdx] = true 274 clients = append(clients, client) 275 clientNames = append(clientNames, client.Name()) 276 hasMatch = true 277 } 278 279 if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) { 280 // Default round-robin query 281 for idx, client := range s.clients { 282 if clientUsed[idx] || client.skipFallback { 283 continue 284 } 285 clientUsed[idx] = true 286 clients = append(clients, client) 287 clientNames = append(clientNames, client.Name()) 288 } 289 } 290 291 if len(domainRules) > 0 { 292 newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog() 293 } 294 if len(clientNames) > 0 { 295 newError("domain ", domain, " will use DNS in order: ", clientNames).AtDebug().WriteToLog() 296 } 297 298 if len(clients) == 0 { 299 clients = append(clients, s.clients[0]) 300 clientNames = append(clientNames, s.clients[0].Name()) 301 newError("domain ", domain, " will use the first DNS: ", clientNames).AtDebug().WriteToLog() 302 } 303 304 return clients 305 } 306 307 func init() { 308 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 309 return New(ctx, config.(*Config)) 310 })) 311 }