github.com/kelleygo/clashcore@v1.0.2/dns/middleware.go (about)

     1  package dns
     2  
     3  import (
     4  	"net/netip"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/kelleygo/clashcore/common/lru"
     9  	"github.com/kelleygo/clashcore/common/nnip"
    10  	"github.com/kelleygo/clashcore/component/fakeip"
    11  	R "github.com/kelleygo/clashcore/component/resolver"
    12  	C "github.com/kelleygo/clashcore/constant"
    13  	"github.com/kelleygo/clashcore/context"
    14  	"github.com/kelleygo/clashcore/log"
    15  
    16  	D "github.com/miekg/dns"
    17  )
    18  
    19  type (
    20  	handler    func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
    21  	middleware func(next handler) handler
    22  )
    23  
    24  func withHosts(hosts R.Hosts, mapping *lru.LruCache[netip.Addr, string]) middleware {
    25  	return func(next handler) handler {
    26  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
    27  			q := r.Question[0]
    28  
    29  			if !isIPRequest(q) {
    30  				return next(ctx, r)
    31  			}
    32  
    33  			host := strings.TrimRight(q.Name, ".")
    34  			handleCName := func(resp *D.Msg, domain string) {
    35  				rr := &D.CNAME{}
    36  				rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
    37  				rr.Target = domain + "."
    38  				resp.Answer = append([]D.RR{rr}, resp.Answer...)
    39  			}
    40  			record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA)
    41  			if !ok {
    42  				if record != nil && record.IsDomain {
    43  					// replace request domain
    44  					newR := r.Copy()
    45  					newR.Question[0].Name = record.Domain + "."
    46  					resp, err := next(ctx, newR)
    47  					if err == nil {
    48  						resp.Id = r.Id
    49  						resp.Question = r.Question
    50  						handleCName(resp, record.Domain)
    51  					}
    52  					return resp, err
    53  				}
    54  				return next(ctx, r)
    55  			}
    56  
    57  			msg := r.Copy()
    58  			handleIPs := func() {
    59  				for _, ipAddr := range record.IPs {
    60  					if ipAddr.Is4() && q.Qtype == D.TypeA {
    61  						rr := &D.A{}
    62  						rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
    63  						rr.A = ipAddr.AsSlice()
    64  						msg.Answer = append(msg.Answer, rr)
    65  						if mapping != nil {
    66  							mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
    67  						}
    68  					} else if q.Qtype == D.TypeAAAA {
    69  						rr := &D.AAAA{}
    70  						rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
    71  						ip := ipAddr.As16()
    72  						rr.AAAA = ip[:]
    73  						msg.Answer = append(msg.Answer, rr)
    74  						if mapping != nil {
    75  							mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
    76  						}
    77  					}
    78  				}
    79  			}
    80  
    81  			switch q.Qtype {
    82  			case D.TypeA:
    83  				handleIPs()
    84  			case D.TypeAAAA:
    85  				handleIPs()
    86  			case D.TypeCNAME:
    87  				handleCName(r, record.Domain)
    88  			default:
    89  				return next(ctx, r)
    90  			}
    91  
    92  			ctx.SetType(context.DNSTypeHost)
    93  			msg.SetRcode(r, D.RcodeSuccess)
    94  			msg.Authoritative = true
    95  			msg.RecursionAvailable = true
    96  			return msg, nil
    97  		}
    98  	}
    99  }
   100  
   101  func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware {
   102  	return func(next handler) handler {
   103  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   104  			q := r.Question[0]
   105  
   106  			if !isIPRequest(q) {
   107  				return next(ctx, r)
   108  			}
   109  
   110  			msg, err := next(ctx, r)
   111  			if err != nil {
   112  				return nil, err
   113  			}
   114  
   115  			host := strings.TrimRight(q.Name, ".")
   116  
   117  			for _, ans := range msg.Answer {
   118  				var ip netip.Addr
   119  				var ttl uint32
   120  
   121  				switch a := ans.(type) {
   122  				case *D.A:
   123  					ip = nnip.IpToAddr(a.A)
   124  					ttl = a.Hdr.Ttl
   125  				case *D.AAAA:
   126  					ip = nnip.IpToAddr(a.AAAA)
   127  					ttl = a.Hdr.Ttl
   128  				default:
   129  					continue
   130  				}
   131  
   132  				if ttl < 1 {
   133  					ttl = 1
   134  				}
   135  
   136  				mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl)))
   137  			}
   138  
   139  			return msg, nil
   140  		}
   141  	}
   142  }
   143  
   144  func withFakeIP(fakePool *fakeip.Pool) middleware {
   145  	return func(next handler) handler {
   146  		return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   147  			q := r.Question[0]
   148  
   149  			host := strings.TrimRight(q.Name, ".")
   150  			if fakePool.ShouldSkipped(host) {
   151  				return next(ctx, r)
   152  			}
   153  
   154  			switch q.Qtype {
   155  			case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
   156  				return handleMsgWithEmptyAnswer(r), nil
   157  			}
   158  
   159  			if q.Qtype != D.TypeA {
   160  				return next(ctx, r)
   161  			}
   162  
   163  			rr := &D.A{}
   164  			rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
   165  			ip := fakePool.Lookup(host)
   166  			rr.A = ip.AsSlice()
   167  			msg := r.Copy()
   168  			msg.Answer = []D.RR{rr}
   169  
   170  			ctx.SetType(context.DNSTypeFakeIP)
   171  			setMsgTTL(msg, 1)
   172  			msg.SetRcode(r, D.RcodeSuccess)
   173  			msg.Authoritative = true
   174  			msg.RecursionAvailable = true
   175  
   176  			return msg, nil
   177  		}
   178  	}
   179  }
   180  
   181  func withResolver(resolver *Resolver) handler {
   182  	return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
   183  		ctx.SetType(context.DNSTypeRaw)
   184  
   185  		q := r.Question[0]
   186  
   187  		// return a empty AAAA msg when ipv6 disabled
   188  		if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
   189  			return handleMsgWithEmptyAnswer(r), nil
   190  		}
   191  
   192  		msg, err := resolver.ExchangeContext(ctx, r)
   193  		if err != nil {
   194  			log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
   195  			return msg, err
   196  		}
   197  		msg.SetRcode(r, msg.Rcode)
   198  		msg.Authoritative = true
   199  
   200  		return msg, nil
   201  	}
   202  }
   203  
   204  func compose(middlewares []middleware, endpoint handler) handler {
   205  	length := len(middlewares)
   206  	h := endpoint
   207  	for i := length - 1; i >= 0; i-- {
   208  		middleware := middlewares[i]
   209  		h = middleware(h)
   210  	}
   211  
   212  	return h
   213  }
   214  
   215  func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
   216  	middlewares := []middleware{}
   217  
   218  	if resolver.hosts != nil {
   219  		middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping))
   220  	}
   221  
   222  	if mapper.mode == C.DNSFakeIP {
   223  		middlewares = append(middlewares, withFakeIP(mapper.fakePool))
   224  	}
   225  
   226  	if mapper.mode != C.DNSNormal {
   227  		middlewares = append(middlewares, withMapping(mapper.mapping))
   228  	}
   229  
   230  	return compose(middlewares, withResolver(resolver))
   231  }