github.com/sagernet/sing-box@v1.9.0-rc.20/route/router_dns.go (about) 1 package route 2 3 import ( 4 "context" 5 "errors" 6 "net/netip" 7 "strings" 8 "time" 9 10 "github.com/sagernet/sing-box/adapter" 11 C "github.com/sagernet/sing-box/constant" 12 "github.com/sagernet/sing-dns" 13 "github.com/sagernet/sing/common/cache" 14 E "github.com/sagernet/sing/common/exceptions" 15 F "github.com/sagernet/sing/common/format" 16 M "github.com/sagernet/sing/common/metadata" 17 18 mDNS "github.com/miekg/dns" 19 ) 20 21 type DNSReverseMapping struct { 22 cache *cache.LruCache[netip.Addr, string] 23 } 24 25 func NewDNSReverseMapping() *DNSReverseMapping { 26 return &DNSReverseMapping{ 27 cache: cache.New[netip.Addr, string](), 28 } 29 } 30 31 func (m *DNSReverseMapping) Save(address netip.Addr, domain string, ttl int) { 32 m.cache.StoreWithExpire(address, domain, time.Now().Add(time.Duration(ttl)*time.Second)) 33 } 34 35 func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) { 36 domain, loaded := m.cache.Load(address) 37 return domain, loaded 38 } 39 40 func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, index int) (context.Context, dns.Transport, dns.DomainStrategy, adapter.DNSRule, int) { 41 metadata := adapter.ContextFrom(ctx) 42 if metadata == nil { 43 panic("no context") 44 } 45 if index < len(r.dnsRules) { 46 dnsRules := r.dnsRules 47 if index != -1 { 48 dnsRules = dnsRules[index+1:] 49 } 50 for currentRuleIndex, rule := range dnsRules { 51 metadata.ResetRuleCache() 52 if rule.Match(metadata) { 53 detour := rule.Outbound() 54 transport, loaded := r.transportMap[detour] 55 if !loaded { 56 r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour) 57 continue 58 } 59 _, isFakeIP := transport.(adapter.FakeIPTransport) 60 if isFakeIP && !allowFakeIP { 61 continue 62 } 63 ruleIndex := currentRuleIndex 64 if index != -1 { 65 ruleIndex += index + 1 66 } 67 r.dnsLogger.DebugContext(ctx, "match[", ruleIndex, "] ", rule.String(), " => ", detour) 68 if isFakeIP || rule.DisableCache() { 69 ctx = dns.ContextWithDisableCache(ctx, true) 70 } 71 if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil { 72 ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL) 73 } 74 if clientSubnet := rule.ClientSubnet(); clientSubnet != nil { 75 ctx = dns.ContextWithClientSubnet(ctx, *clientSubnet) 76 } 77 if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { 78 return ctx, transport, domainStrategy, rule, ruleIndex 79 } else { 80 return ctx, transport, r.defaultDomainStrategy, rule, ruleIndex 81 } 82 } 83 } 84 } 85 if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded { 86 return ctx, r.defaultTransport, domainStrategy, nil, -1 87 } else { 88 return ctx, r.defaultTransport, r.defaultDomainStrategy, nil, -1 89 } 90 } 91 92 func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { 93 if len(message.Question) > 0 { 94 r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String())) 95 } 96 var ( 97 response *mDNS.Msg 98 cached bool 99 transport dns.Transport 100 err error 101 ) 102 response, cached = r.dnsClient.ExchangeCache(ctx, message) 103 if !cached { 104 var metadata *adapter.InboundContext 105 ctx, metadata = adapter.AppendContext(ctx) 106 if len(message.Question) > 0 { 107 metadata.QueryType = message.Question[0].Qtype 108 switch metadata.QueryType { 109 case mDNS.TypeA: 110 metadata.IPVersion = 4 111 case mDNS.TypeAAAA: 112 metadata.IPVersion = 6 113 } 114 metadata.Domain = fqdnToDomain(message.Question[0].Name) 115 } 116 var ( 117 strategy dns.DomainStrategy 118 rule adapter.DNSRule 119 ruleIndex int 120 ) 121 ruleIndex = -1 122 for { 123 var ( 124 dnsCtx context.Context 125 cancel context.CancelFunc 126 addressLimit bool 127 ) 128 129 dnsCtx, transport, strategy, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex) 130 dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout) 131 if rule != nil && rule.WithAddressLimit() && isAddressQuery(message) { 132 addressLimit = true 133 response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, strategy, func(response *mDNS.Msg) bool { 134 metadata.DestinationAddresses, _ = dns.MessageToAddresses(response) 135 return rule.MatchAddressLimit(metadata) 136 }) 137 } else { 138 addressLimit = false 139 response, err = r.dnsClient.Exchange(dnsCtx, transport, message, strategy) 140 } 141 cancel() 142 var rejected bool 143 if err != nil { 144 if errors.Is(err, dns.ErrResponseRejectedCached) { 145 rejected = true 146 r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())), " (cached)") 147 } else if errors.Is(err, dns.ErrResponseRejected) { 148 rejected = true 149 r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String()))) 150 } else if len(message.Question) > 0 { 151 r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String()))) 152 } else { 153 r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>")) 154 } 155 } 156 if addressLimit && rejected { 157 continue 158 } 159 break 160 } 161 } 162 if err != nil { 163 return nil, err 164 } 165 if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 { 166 if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP { 167 for _, answer := range response.Answer { 168 switch record := answer.(type) { 169 case *mDNS.A: 170 r.dnsReverseMapping.Save(M.AddrFromIP(record.A), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl)) 171 case *mDNS.AAAA: 172 r.dnsReverseMapping.Save(M.AddrFromIP(record.AAAA), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl)) 173 } 174 } 175 } 176 } 177 return response, nil 178 } 179 180 func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { 181 var ( 182 responseAddrs []netip.Addr 183 cached bool 184 err error 185 ) 186 responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy) 187 if cached { 188 return responseAddrs, nil 189 } 190 r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) 191 ctx, metadata := adapter.AppendContext(ctx) 192 metadata.Domain = domain 193 var ( 194 transport dns.Transport 195 transportStrategy dns.DomainStrategy 196 rule adapter.DNSRule 197 ruleIndex int 198 ) 199 ruleIndex = -1 200 for { 201 var ( 202 dnsCtx context.Context 203 cancel context.CancelFunc 204 addressLimit bool 205 ) 206 metadata.ResetRuleCache() 207 metadata.DestinationAddresses = nil 208 dnsCtx, transport, transportStrategy, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex) 209 if strategy == dns.DomainStrategyAsIS { 210 strategy = transportStrategy 211 } 212 dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout) 213 if rule != nil && rule.WithAddressLimit() { 214 addressLimit = true 215 responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, strategy, func(responseAddrs []netip.Addr) bool { 216 metadata.DestinationAddresses = responseAddrs 217 return rule.MatchAddressLimit(metadata) 218 }) 219 } else { 220 addressLimit = false 221 responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, strategy) 222 } 223 cancel() 224 if err != nil { 225 if errors.Is(err, dns.ErrResponseRejectedCached) { 226 r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)") 227 } else if errors.Is(err, dns.ErrResponseRejected) { 228 r.dnsLogger.DebugContext(ctx, "response rejected for ", domain) 229 } else { 230 r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) 231 } 232 } else if len(responseAddrs) == 0 { 233 r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") 234 err = dns.RCodeNameError 235 } 236 if !addressLimit || err == nil { 237 break 238 } 239 } 240 if len(responseAddrs) > 0 { 241 r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " ")) 242 } 243 return responseAddrs, err 244 } 245 246 func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { 247 return r.Lookup(ctx, domain, dns.DomainStrategyAsIS) 248 } 249 250 func (r *Router) ClearDNSCache() { 251 r.dnsClient.ClearCache() 252 if r.platformInterface != nil { 253 r.platformInterface.ClearDNSCache() 254 } 255 } 256 257 func isAddressQuery(message *mDNS.Msg) bool { 258 for _, question := range message.Question { 259 if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { 260 return true 261 } 262 } 263 return false 264 } 265 266 func fqdnToDomain(fqdn string) string { 267 if mDNS.IsFqdn(fqdn) { 268 return fqdn[:len(fqdn)-1] 269 } 270 return fqdn 271 } 272 273 func formatQuestion(string string) string { 274 if strings.HasPrefix(string, ";") { 275 string = string[1:] 276 } 277 string = strings.ReplaceAll(string, "\t", " ") 278 for strings.Contains(string, " ") { 279 string = strings.ReplaceAll(string, " ", " ") 280 } 281 return string 282 }