github.com/metacubex/mihomo@v1.18.5/dns/middleware.go (about) 1 package dns 2 3 import ( 4 "net/netip" 5 "strings" 6 "time" 7 8 "github.com/metacubex/mihomo/common/lru" 9 "github.com/metacubex/mihomo/common/nnip" 10 "github.com/metacubex/mihomo/component/fakeip" 11 R "github.com/metacubex/mihomo/component/resolver" 12 C "github.com/metacubex/mihomo/constant" 13 "github.com/metacubex/mihomo/context" 14 "github.com/metacubex/mihomo/log" 15 16 D "github.com/miekg/dns" 17 ) 18 19 type ( 20 handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) 21 middleware func(next handler) handler 22 ) 23 24 func withHosts(hosts R.Hosts, mapping *lru.LruCache[netip.Addr, string]) middleware { 25 return func(next handler) handler { 26 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 27 q := r.Question[0] 28 29 if !isIPRequest(q) { 30 return next(ctx, r) 31 } 32 33 host := strings.TrimRight(q.Name, ".") 34 handleCName := func(resp *D.Msg, domain string) { 35 rr := &D.CNAME{} 36 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10} 37 rr.Target = domain + "." 38 resp.Answer = append([]D.RR{rr}, resp.Answer...) 39 } 40 record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA) 41 if !ok { 42 if record != nil && record.IsDomain { 43 // replace request domain 44 newR := r.Copy() 45 newR.Question[0].Name = record.Domain + "." 46 resp, err := next(ctx, newR) 47 if err == nil { 48 resp.Id = r.Id 49 resp.Question = r.Question 50 handleCName(resp, record.Domain) 51 } 52 return resp, err 53 } 54 return next(ctx, r) 55 } 56 57 msg := r.Copy() 58 handleIPs := func() { 59 for _, ipAddr := range record.IPs { 60 if ipAddr.Is4() && q.Qtype == D.TypeA { 61 rr := &D.A{} 62 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} 63 rr.A = ipAddr.AsSlice() 64 msg.Answer = append(msg.Answer, rr) 65 if mapping != nil { 66 mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10)) 67 } 68 } else if q.Qtype == D.TypeAAAA { 69 rr := &D.AAAA{} 70 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10} 71 ip := ipAddr.As16() 72 rr.AAAA = ip[:] 73 msg.Answer = append(msg.Answer, rr) 74 if mapping != nil { 75 mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10)) 76 } 77 } 78 } 79 } 80 81 switch q.Qtype { 82 case D.TypeA: 83 handleIPs() 84 case D.TypeAAAA: 85 handleIPs() 86 case D.TypeCNAME: 87 handleCName(r, record.Domain) 88 default: 89 return next(ctx, r) 90 } 91 92 ctx.SetType(context.DNSTypeHost) 93 msg.SetRcode(r, D.RcodeSuccess) 94 msg.Authoritative = true 95 msg.RecursionAvailable = true 96 return msg, nil 97 } 98 } 99 } 100 101 func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware { 102 return func(next handler) handler { 103 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 104 q := r.Question[0] 105 106 if !isIPRequest(q) { 107 return next(ctx, r) 108 } 109 110 msg, err := next(ctx, r) 111 if err != nil { 112 return nil, err 113 } 114 115 host := strings.TrimRight(q.Name, ".") 116 117 for _, ans := range msg.Answer { 118 var ip netip.Addr 119 var ttl uint32 120 121 switch a := ans.(type) { 122 case *D.A: 123 ip = nnip.IpToAddr(a.A) 124 ttl = a.Hdr.Ttl 125 case *D.AAAA: 126 ip = nnip.IpToAddr(a.AAAA) 127 ttl = a.Hdr.Ttl 128 default: 129 continue 130 } 131 132 if ttl < 1 { 133 ttl = 1 134 } 135 136 mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl))) 137 } 138 139 return msg, nil 140 } 141 } 142 } 143 144 func withFakeIP(fakePool *fakeip.Pool) middleware { 145 return func(next handler) handler { 146 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 147 q := r.Question[0] 148 149 host := strings.TrimRight(q.Name, ".") 150 if fakePool.ShouldSkipped(host) { 151 return next(ctx, r) 152 } 153 154 switch q.Qtype { 155 case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS: 156 return handleMsgWithEmptyAnswer(r), nil 157 } 158 159 if q.Qtype != D.TypeA { 160 return next(ctx, r) 161 } 162 163 rr := &D.A{} 164 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} 165 ip := fakePool.Lookup(host) 166 rr.A = ip.AsSlice() 167 msg := r.Copy() 168 msg.Answer = []D.RR{rr} 169 170 ctx.SetType(context.DNSTypeFakeIP) 171 setMsgTTL(msg, 1) 172 msg.SetRcode(r, D.RcodeSuccess) 173 msg.Authoritative = true 174 msg.RecursionAvailable = true 175 176 return msg, nil 177 } 178 } 179 } 180 181 func withResolver(resolver *Resolver) handler { 182 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 183 ctx.SetType(context.DNSTypeRaw) 184 185 q := r.Question[0] 186 187 // return a empty AAAA msg when ipv6 disabled 188 if !resolver.ipv6 && q.Qtype == D.TypeAAAA { 189 return handleMsgWithEmptyAnswer(r), nil 190 } 191 192 msg, err := resolver.ExchangeContext(ctx, r) 193 if err != nil { 194 log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) 195 return msg, err 196 } 197 msg.SetRcode(r, msg.Rcode) 198 msg.Authoritative = true 199 200 return msg, nil 201 } 202 } 203 204 func compose(middlewares []middleware, endpoint handler) handler { 205 length := len(middlewares) 206 h := endpoint 207 for i := length - 1; i >= 0; i-- { 208 middleware := middlewares[i] 209 h = middleware(h) 210 } 211 212 return h 213 } 214 215 func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { 216 middlewares := []middleware{} 217 218 if resolver.hosts != nil { 219 middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping)) 220 } 221 222 if mapper.mode == C.DNSFakeIP { 223 middlewares = append(middlewares, withFakeIP(mapper.fakePool)) 224 } 225 226 if mapper.mode != C.DNSNormal { 227 middlewares = append(middlewares, withMapping(mapper.mapping)) 228 } 229 230 return compose(middlewares, withResolver(resolver)) 231 }