github.com/chwjbn/xclash@v0.2.0/dns/middleware.go (about)

     1  package dns
     2  
     3  import (
     4  	"net"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/chwjbn/xclash/common/cache"
     9  	"github.com/chwjbn/xclash/component/fakeip"
    10  	"github.com/chwjbn/xclash/component/trie"
    11  	C "github.com/chwjbn/xclash/constant"
    12  	"github.com/chwjbn/xclash/context"
    13  	"github.com/chwjbn/xclash/log"
    14  
    15  	D "github.com/miekg/dns"
    16  )
    17  
    18  type (
    19  	handler    func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
    20  	middleware func(next handler) handler
    21  )
    22  
    23  func withHosts(hosts *trie.DomainTrie) middleware {
    24  	return func(next handler) handler {
    25  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    26  			q := r.Question[0]
    27  
    28  			if !isIPRequest(q) {
    29  				return next(ctx, r)
    30  			}
    31  
    32  			record := hosts.Search(strings.TrimRight(q.Name, "."))
    33  			if record == nil {
    34  				return next(ctx, r)
    35  			}
    36  
    37  			ip := record.Data.(net.IP)
    38  			msg := r.Copy()
    39  
    40  			if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA {
    41  				rr := &D.A{}
    42  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
    43  				rr.A = v4
    44  
    45  				msg.Answer = []D.RR{rr}
    46  			} else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA {
    47  				rr := &D.AAAA{}
    48  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
    49  				rr.AAAA = v6
    50  
    51  				msg.Answer = []D.RR{rr}
    52  			} else {
    53  				return next(ctx, r)
    54  			}
    55  
    56  			ctx.SetType(context.DNSTypeHost)
    57  			msg.SetRcode(r, D.RcodeSuccess)
    58  			msg.Authoritative = true
    59  			msg.RecursionAvailable = true
    60  
    61  			return msg, nil
    62  		}
    63  	}
    64  }
    65  
    66  func withMapping(mapping *cache.LruCache) middleware {
    67  	return func(next handler) handler {
    68  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    69  			q := r.Question[0]
    70  
    71  			if !isIPRequest(q) {
    72  				return next(ctx, r)
    73  			}
    74  
    75  			msg, err := next(ctx, r)
    76  			if err != nil {
    77  				return nil, err
    78  			}
    79  
    80  			host := strings.TrimRight(q.Name, ".")
    81  
    82  			for _, ans := range msg.Answer {
    83  				var ip net.IP
    84  				var ttl uint32
    85  
    86  				switch a := ans.(type) {
    87  				case *D.A:
    88  					ip = a.A
    89  					ttl = a.Hdr.Ttl
    90  				case *D.AAAA:
    91  					ip = a.AAAA
    92  					ttl = a.Hdr.Ttl
    93  				default:
    94  					continue
    95  				}
    96  
    97  				mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl)))
    98  			}
    99  
   100  			return msg, nil
   101  		}
   102  	}
   103  }
   104  
   105  func withFakeIP(fakePool *fakeip.Pool) middleware {
   106  	return func(next handler) handler {
   107  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   108  			q := r.Question[0]
   109  
   110  			host := strings.TrimRight(q.Name, ".")
   111  			if fakePool.ShouldSkipped(host) {
   112  				return next(ctx, r)
   113  			}
   114  
   115  			switch q.Qtype {
   116  			case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
   117  				return handleMsgWithEmptyAnswer(r), nil
   118  			}
   119  
   120  			if q.Qtype != D.TypeA {
   121  				return next(ctx, r)
   122  			}
   123  
   124  			rr := &D.A{}
   125  			rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
   126  			ip := fakePool.Lookup(host)
   127  			rr.A = ip
   128  			msg := r.Copy()
   129  			msg.Answer = []D.RR{rr}
   130  
   131  			ctx.SetType(context.DNSTypeFakeIP)
   132  			setMsgTTL(msg, 1)
   133  			msg.SetRcode(r, D.RcodeSuccess)
   134  			msg.Authoritative = true
   135  			msg.RecursionAvailable = true
   136  
   137  			return msg, nil
   138  		}
   139  	}
   140  }
   141  
   142  func withResolver(resolver *Resolver) handler {
   143  	return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   144  		ctx.SetType(context.DNSTypeRaw)
   145  		q := r.Question[0]
   146  
   147  		// return a empty AAAA msg when ipv6 disabled
   148  		if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
   149  			return handleMsgWithEmptyAnswer(r), nil
   150  		}
   151  
   152  		msg, err := resolver.Exchange(r)
   153  		if err != nil {
   154  			log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
   155  			return msg, err
   156  		}
   157  		msg.SetRcode(r, msg.Rcode)
   158  		msg.Authoritative = true
   159  
   160  		return msg, nil
   161  	}
   162  }
   163  
   164  func compose(middlewares []middleware, endpoint handler) handler {
   165  	length := len(middlewares)
   166  	h := endpoint
   167  	for i := length - 1; i >= 0; i-- {
   168  		middleware := middlewares[i]
   169  		h = middleware(h)
   170  	}
   171  
   172  	return h
   173  }
   174  
   175  func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
   176  	middlewares := []middleware{}
   177  
   178  	if resolver.hosts != nil {
   179  		middlewares = append(middlewares, withHosts(resolver.hosts))
   180  	}
   181  
   182  	if mapper.mode == C.DNSFakeIP {
   183  		middlewares = append(middlewares, withFakeIP(mapper.fakePool))
   184  	}
   185  
   186  	if mapper.mode != C.DNSNormal {
   187  		middlewares = append(middlewares, withMapping(mapper.mapping))
   188  	}
   189  
   190  	return compose(middlewares, withResolver(resolver))
   191  }