github.com/igoogolx/clash@v1.19.8/dns/middleware.go (about)

     1  package dns
     2  
     3  import (
     4  	"net"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/igoogolx/clash/common/cache"
     9  	"github.com/igoogolx/clash/component/fakeip"
    10  	"github.com/igoogolx/clash/component/trie"
    11  	C "github.com/igoogolx/clash/constant"
    12  	"github.com/igoogolx/clash/context"
    13  	"github.com/igoogolx/clash/log"
    14  
    15  	D "github.com/miekg/dns"
    16  )
    17  
    18  type (
    19  	handler    func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
    20  	middleware func(next handler) handler
    21  )
    22  
    23  func withHosts(hosts *trie.DomainTrie) middleware {
    24  	return func(next handler) handler {
    25  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    26  			q := r.Question[0]
    27  
    28  			if !isIPRequest(q) {
    29  				return next(ctx, r)
    30  			}
    31  
    32  			record := hosts.Search(strings.TrimRight(q.Name, "."))
    33  			if record == nil {
    34  				return next(ctx, r)
    35  			}
    36  
    37  			ip := record.Data.(net.IP)
    38  			msg := r.Copy()
    39  
    40  			if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA {
    41  				rr := &D.A{}
    42  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
    43  				rr.A = v4
    44  
    45  				msg.Answer = []D.RR{rr}
    46  			} else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA {
    47  				rr := &D.AAAA{}
    48  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
    49  				rr.AAAA = v6
    50  
    51  				msg.Answer = []D.RR{rr}
    52  			} else {
    53  				return next(ctx, r)
    54  			}
    55  
    56  			ctx.SetType(context.DNSTypeHost)
    57  			msg.SetRcode(r, D.RcodeSuccess)
    58  			msg.Authoritative = true
    59  			msg.RecursionAvailable = true
    60  
    61  			return msg, nil
    62  		}
    63  	}
    64  }
    65  
    66  func withMapping(mapping *cache.LruCache) middleware {
    67  	return func(next handler) handler {
    68  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    69  			q := r.Question[0]
    70  
    71  			if !isIPRequest(q) {
    72  				return next(ctx, r)
    73  			}
    74  
    75  			msg, err := next(ctx, r)
    76  			if err != nil {
    77  				return nil, err
    78  			}
    79  
    80  			host := strings.TrimRight(q.Name, ".")
    81  
    82  			for _, ans := range msg.Answer {
    83  				var ip net.IP
    84  				var ttl uint32
    85  
    86  				switch a := ans.(type) {
    87  				case *D.A:
    88  					ip = a.A
    89  					ttl = a.Hdr.Ttl
    90  					if !ip.IsGlobalUnicast() {
    91  						continue
    92  					}
    93  				case *D.AAAA:
    94  					ip = a.AAAA
    95  					ttl = a.Hdr.Ttl
    96  					if !ip.IsGlobalUnicast() {
    97  						continue
    98  					}
    99  				default:
   100  					continue
   101  				}
   102  
   103  				if ttl < 1 {
   104  					ttl = 1
   105  				}
   106  				mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl)))
   107  			}
   108  
   109  			return msg, nil
   110  		}
   111  	}
   112  }
   113  
   114  func withFakeIP(fakePool *fakeip.Pool) middleware {
   115  	return func(next handler) handler {
   116  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   117  			q := r.Question[0]
   118  
   119  			host := strings.TrimRight(q.Name, ".")
   120  			if fakePool.ShouldSkipped(host) {
   121  				return next(ctx, r)
   122  			}
   123  
   124  			switch q.Qtype {
   125  			case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
   126  				return handleMsgWithEmptyAnswer(r), nil
   127  			}
   128  
   129  			if q.Qtype != D.TypeA {
   130  				return next(ctx, r)
   131  			}
   132  
   133  			rr := &D.A{}
   134  			rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
   135  			ip := fakePool.Lookup(host)
   136  			rr.A = ip
   137  			msg := r.Copy()
   138  			msg.Answer = []D.RR{rr}
   139  
   140  			ctx.SetType(context.DNSTypeFakeIP)
   141  			setMsgTTL(msg, 1)
   142  			msg.SetRcode(r, D.RcodeSuccess)
   143  			msg.Authoritative = true
   144  			msg.RecursionAvailable = true
   145  
   146  			return msg, nil
   147  		}
   148  	}
   149  }
   150  
   151  func withResolver(resolver *Resolver) handler {
   152  	return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   153  		ctx.SetType(context.DNSTypeRaw)
   154  		q := r.Question[0]
   155  
   156  		// return a empty AAAA msg when ipv6 disabled
   157  		if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
   158  			return handleMsgWithEmptyAnswer(r), nil
   159  		}
   160  
   161  		msg, err := resolver.Exchange(r)
   162  		if err != nil {
   163  			log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
   164  			return msg, err
   165  		}
   166  		msg.SetRcode(r, msg.Rcode)
   167  		msg.Authoritative = true
   168  
   169  		return msg, nil
   170  	}
   171  }
   172  
   173  func compose(middlewares []middleware, endpoint handler) handler {
   174  	length := len(middlewares)
   175  	h := endpoint
   176  	for i := length - 1; i >= 0; i-- {
   177  		middleware := middlewares[i]
   178  		h = middleware(h)
   179  	}
   180  
   181  	return h
   182  }
   183  
   184  func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
   185  	middlewares := []middleware{}
   186  
   187  	if resolver.hosts != nil {
   188  		middlewares = append(middlewares, withHosts(resolver.hosts))
   189  	}
   190  
   191  	if mapper.mode == C.DNSFakeIP {
   192  		middlewares = append(middlewares, withFakeIP(mapper.fakePool))
   193  		middlewares = append(middlewares, withMapping(mapper.mapping))
   194  	}
   195  
   196  	return compose(middlewares, withResolver(resolver))
   197  }