github.com/yaling888/clash@v1.53.0/dns/middleware.go (about)

     1  package dns
     2  
     3  import (
     4  	"net/netip"
     5  	"strings"
     6  	"time"
     7  
     8  	D "github.com/miekg/dns"
     9  	"github.com/phuslu/log"
    10  
    11  	"github.com/yaling888/clash/common/cache"
    12  	"github.com/yaling888/clash/component/fakeip"
    13  	"github.com/yaling888/clash/component/trie"
    14  	C "github.com/yaling888/clash/constant"
    15  	"github.com/yaling888/clash/context"
    16  )
    17  
    18  type (
    19  	handler    func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
    20  	middleware func(next handler) handler
    21  )
    22  
    23  func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip.Addr, string]) middleware {
    24  	return func(next handler) handler {
    25  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    26  			q := r.Question[0]
    27  
    28  			if !isIPRequest(q) {
    29  				return next(ctx, r)
    30  			}
    31  
    32  			host := strings.TrimRight(q.Name, ".")
    33  
    34  			record := hosts.Search(host)
    35  			if record == nil {
    36  				return next(ctx, r)
    37  			}
    38  
    39  			ip := record.Data
    40  			msg := r.Copy()
    41  
    42  			if ip.Is4() && q.Qtype == D.TypeA {
    43  				rr := &D.A{}
    44  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 3}
    45  				rr.A = ip.AsSlice()
    46  
    47  				msg.Answer = []D.RR{rr}
    48  			} else if ip.Is6() && q.Qtype == D.TypeAAAA {
    49  				rr := &D.AAAA{}
    50  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 3}
    51  				rr.AAAA = ip.AsSlice()
    52  
    53  				msg.Answer = []D.RR{rr}
    54  			} else {
    55  				return next(ctx, r)
    56  			}
    57  
    58  			if mapping != nil {
    59  				mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*3))
    60  			}
    61  
    62  			ctx.SetType(context.DNSTypeHost)
    63  			msg.SetRcode(r, D.RcodeSuccess)
    64  			msg.Authoritative = true
    65  			msg.RecursionAvailable = true
    66  
    67  			return msg, nil
    68  		}
    69  	}
    70  }
    71  
    72  func withMapping(mapping *cache.LruCache[netip.Addr, string]) middleware {
    73  	return func(next handler) handler {
    74  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    75  			q := r.Question[0]
    76  
    77  			if !isIPRequest(q) {
    78  				return next(ctx, r)
    79  			}
    80  
    81  			msg, err := next(ctx, r)
    82  			if err != nil {
    83  				return nil, err
    84  			}
    85  
    86  			host := strings.TrimRight(q.Name, ".")
    87  
    88  			for _, ans := range msg.Answer {
    89  				var ip netip.Addr
    90  				var ttl uint32
    91  
    92  				switch a := ans.(type) {
    93  				case *D.A:
    94  					ip, _ = netip.AddrFromSlice(a.A.To4())
    95  					ttl = a.Hdr.Ttl
    96  					if !ip.IsGlobalUnicast() {
    97  						continue
    98  					}
    99  				case *D.AAAA:
   100  					ip, _ = netip.AddrFromSlice(a.AAAA)
   101  					ttl = a.Hdr.Ttl
   102  					if !ip.IsGlobalUnicast() {
   103  						continue
   104  					}
   105  				default:
   106  					continue
   107  				}
   108  
   109  				ttl = max(ttl, 1)
   110  				mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl)))
   111  			}
   112  
   113  			return msg, nil
   114  		}
   115  	}
   116  }
   117  
   118  func withFakeIP(fakePool *fakeip.Pool) middleware {
   119  	return func(next handler) handler {
   120  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   121  			q := r.Question[0]
   122  
   123  			host := strings.TrimRight(q.Name, ".")
   124  			if fakePool.ShouldSkipped(host) {
   125  				return next(ctx, r)
   126  			}
   127  
   128  			switch q.Qtype {
   129  			case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
   130  				return handleMsgWithEmptyAnswer(r), nil
   131  			}
   132  
   133  			if q.Qtype != D.TypeA {
   134  				return next(ctx, r)
   135  			}
   136  
   137  			rr := &D.A{}
   138  			rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
   139  			ip := fakePool.Lookup(host)
   140  			rr.A = ip.AsSlice()
   141  			msg := r.Copy()
   142  			msg.Answer = []D.RR{rr}
   143  
   144  			ctx.SetType(context.DNSTypeFakeIP)
   145  			setMsgTTL(msg, 3)
   146  			msg.SetRcode(r, D.RcodeSuccess)
   147  			msg.Authoritative = true
   148  			msg.RecursionAvailable = true
   149  
   150  			return msg, nil
   151  		}
   152  	}
   153  }
   154  
   155  func withResolver(resolver *Resolver) handler {
   156  	return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   157  		ctx.SetType(context.DNSTypeRaw)
   158  		q := r.Question[0]
   159  
   160  		// return an empty AAAA msg when ipv6 disabled
   161  		if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
   162  			return handleMsgWithEmptyAnswer(r), nil
   163  		}
   164  
   165  		msg, _, err := resolver.Exchange(r)
   166  		if err != nil {
   167  			log.Debug().Err(err).
   168  				Str("name", q.Name).
   169  				Str("qClass", D.Class(q.Qclass).String()).
   170  				Str("qType", D.Type(q.Qtype).String()).
   171  				Msg("[DNS] exchange failed")
   172  			return nil, err
   173  		}
   174  		msg.SetRcode(r, msg.Rcode)
   175  		msg.Authoritative = true
   176  
   177  		return msg, nil
   178  	}
   179  }
   180  
   181  func compose(middlewares []middleware, endpoint handler) handler {
   182  	length := len(middlewares)
   183  	h := endpoint
   184  	for i := length - 1; i >= 0; i-- {
   185  		mMiddleware := middlewares[i]
   186  		h = mMiddleware(h)
   187  	}
   188  
   189  	return h
   190  }
   191  
   192  func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
   193  	middlewares := []middleware{}
   194  
   195  	if resolver.hosts != nil {
   196  		middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping))
   197  	}
   198  
   199  	if mapper.mode == C.DNSFakeIP {
   200  		middlewares = append(middlewares, withFakeIP(mapper.fakePool))
   201  		middlewares = append(middlewares, withMapping(mapper.mapping))
   202  	}
   203  
   204  	return compose(middlewares, withResolver(resolver))
   205  }