github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/route/router_dns.go (about)

     1  package route
     2  
     3  import (
     4  	"context"
     5  	"net/netip"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/inazumav/sing-box/adapter"
    10  	C "github.com/inazumav/sing-box/constant"
    11  	"github.com/inazumav/sing-box/log"
    12  	"github.com/sagernet/sing-dns"
    13  	"github.com/sagernet/sing/common/cache"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	F "github.com/sagernet/sing/common/format"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  
    18  	mDNS "github.com/miekg/dns"
    19  )
    20  
    21  type DNSReverseMapping struct {
    22  	cache *cache.LruCache[netip.Addr, string]
    23  }
    24  
    25  func NewDNSReverseMapping() *DNSReverseMapping {
    26  	return &DNSReverseMapping{
    27  		cache: cache.New[netip.Addr, string](),
    28  	}
    29  }
    30  
    31  func (m *DNSReverseMapping) Save(address netip.Addr, domain string, ttl int) {
    32  	m.cache.StoreWithExpire(address, domain, time.Now().Add(time.Duration(ttl)*time.Second))
    33  }
    34  
    35  func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) {
    36  	domain, loaded := m.cache.Load(address)
    37  	return domain, loaded
    38  }
    39  
    40  func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport, dns.DomainStrategy) {
    41  	metadata := adapter.ContextFrom(ctx)
    42  	if metadata == nil {
    43  		panic("no context")
    44  	}
    45  	for i, rule := range r.dnsRules {
    46  		if rule.Match(metadata) {
    47  			detour := rule.Outbound()
    48  			transport, loaded := r.transportMap[detour]
    49  			if !loaded {
    50  				r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
    51  				continue
    52  			}
    53  			if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && metadata.FakeIP {
    54  				continue
    55  			}
    56  			r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
    57  			if rule.DisableCache() {
    58  				ctx = dns.ContextWithDisableCache(ctx, true)
    59  			}
    60  			if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil {
    61  				ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL)
    62  			}
    63  			if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
    64  				return ctx, transport, domainStrategy
    65  			} else {
    66  				return ctx, transport, r.defaultDomainStrategy
    67  			}
    68  		}
    69  	}
    70  	if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded {
    71  		return ctx, r.defaultTransport, domainStrategy
    72  	} else {
    73  		return ctx, r.defaultTransport, r.defaultDomainStrategy
    74  	}
    75  }
    76  
    77  func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
    78  	if len(message.Question) > 0 {
    79  		r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()))
    80  	}
    81  	var (
    82  		response *mDNS.Msg
    83  		cached   bool
    84  		err      error
    85  	)
    86  	response, cached = r.dnsClient.ExchangeCache(ctx, message)
    87  	if !cached {
    88  		ctx, metadata := adapter.AppendContext(ctx)
    89  		if len(message.Question) > 0 {
    90  			metadata.QueryType = message.Question[0].Qtype
    91  			switch metadata.QueryType {
    92  			case mDNS.TypeA:
    93  				metadata.IPVersion = 4
    94  			case mDNS.TypeAAAA:
    95  				metadata.IPVersion = 6
    96  			}
    97  			metadata.Domain = fqdnToDomain(message.Question[0].Name)
    98  		}
    99  		ctx, transport, strategy := r.matchDNS(ctx)
   100  		ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
   101  		defer cancel()
   102  		response, err = r.dnsClient.Exchange(ctx, transport, message, strategy)
   103  		if err != nil && len(message.Question) > 0 {
   104  			r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
   105  		}
   106  	}
   107  	if len(message.Question) > 0 && response != nil {
   108  		LogDNSAnswers(r.dnsLogger, ctx, message.Question[0].Name, response.Answer)
   109  	}
   110  	if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
   111  		for _, answer := range response.Answer {
   112  			switch record := answer.(type) {
   113  			case *mDNS.A:
   114  				r.dnsReverseMapping.Save(M.AddrFromIP(record.A), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
   115  			case *mDNS.AAAA:
   116  				r.dnsReverseMapping.Save(M.AddrFromIP(record.AAAA), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
   117  			}
   118  		}
   119  	}
   120  	return response, err
   121  }
   122  
   123  func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
   124  	r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
   125  	ctx, metadata := adapter.AppendContext(ctx)
   126  	metadata.Domain = domain
   127  	ctx, transport, transportStrategy := r.matchDNS(ctx)
   128  	if strategy == dns.DomainStrategyAsIS {
   129  		strategy = transportStrategy
   130  	}
   131  	ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
   132  	defer cancel()
   133  	addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy)
   134  	if len(addrs) > 0 {
   135  		r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " "))
   136  	} else {
   137  		r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
   138  		if err == nil {
   139  			err = dns.RCodeNameError
   140  		}
   141  	}
   142  	return addrs, err
   143  }
   144  
   145  func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
   146  	return r.Lookup(ctx, domain, dns.DomainStrategyAsIS)
   147  }
   148  
   149  func (r *Router) ClearDNSCache() {
   150  	r.dnsClient.ClearCache()
   151  	if r.platformInterface != nil {
   152  		r.platformInterface.ClearDNSCache()
   153  	}
   154  }
   155  
   156  func LogDNSAnswers(logger log.ContextLogger, ctx context.Context, domain string, answers []mDNS.RR) {
   157  	for _, answer := range answers {
   158  		logger.InfoContext(ctx, "exchanged ", domain, " ", mDNS.Type(answer.Header().Rrtype).String(), " ", formatQuestion(answer.String()))
   159  	}
   160  }
   161  
   162  func fqdnToDomain(fqdn string) string {
   163  	if mDNS.IsFqdn(fqdn) {
   164  		return fqdn[:len(fqdn)-1]
   165  	}
   166  	return fqdn
   167  }
   168  
   169  func formatQuestion(string string) string {
   170  	if strings.HasPrefix(string, ";") {
   171  		string = string[1:]
   172  	}
   173  	string = strings.ReplaceAll(string, "\t", " ")
   174  	for strings.Contains(string, "  ") {
   175  		string = strings.ReplaceAll(string, "  ", " ")
   176  	}
   177  	return string
   178  }