github.com/chwjbn/xclash@v0.2.0/dns/middleware.go (about) 1 package dns 2 3 import ( 4 "net" 5 "strings" 6 "time" 7 8 "github.com/chwjbn/xclash/common/cache" 9 "github.com/chwjbn/xclash/component/fakeip" 10 "github.com/chwjbn/xclash/component/trie" 11 C "github.com/chwjbn/xclash/constant" 12 "github.com/chwjbn/xclash/context" 13 "github.com/chwjbn/xclash/log" 14 15 D "github.com/miekg/dns" 16 ) 17 18 type ( 19 handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) 20 middleware func(next handler) handler 21 ) 22 23 func withHosts(hosts *trie.DomainTrie) middleware { 24 return func(next handler) handler { 25 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 26 q := r.Question[0] 27 28 if !isIPRequest(q) { 29 return next(ctx, r) 30 } 31 32 record := hosts.Search(strings.TrimRight(q.Name, ".")) 33 if record == nil { 34 return next(ctx, r) 35 } 36 37 ip := record.Data.(net.IP) 38 msg := r.Copy() 39 40 if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA { 41 rr := &D.A{} 42 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} 43 rr.A = v4 44 45 msg.Answer = []D.RR{rr} 46 } else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA { 47 rr := &D.AAAA{} 48 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} 49 rr.AAAA = v6 50 51 msg.Answer = []D.RR{rr} 52 } else { 53 return next(ctx, r) 54 } 55 56 ctx.SetType(context.DNSTypeHost) 57 msg.SetRcode(r, D.RcodeSuccess) 58 msg.Authoritative = true 59 msg.RecursionAvailable = true 60 61 return msg, nil 62 } 63 } 64 } 65 66 func withMapping(mapping *cache.LruCache) middleware { 67 return func(next handler) handler { 68 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 69 q := r.Question[0] 70 71 if !isIPRequest(q) { 72 return next(ctx, r) 73 } 74 75 msg, err := next(ctx, r) 76 if err != nil { 77 return nil, err 78 } 79 80 host := strings.TrimRight(q.Name, ".") 81 82 for _, ans := range msg.Answer { 83 var ip net.IP 84 var ttl uint32 85 86 switch a := ans.(type) { 87 case *D.A: 88 ip = a.A 89 ttl = a.Hdr.Ttl 90 case *D.AAAA: 91 ip = a.AAAA 92 ttl = a.Hdr.Ttl 93 default: 94 continue 95 } 96 97 mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl))) 98 } 99 100 return msg, nil 101 } 102 } 103 } 104 105 func withFakeIP(fakePool *fakeip.Pool) middleware { 106 return func(next handler) handler { 107 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 108 q := r.Question[0] 109 110 host := strings.TrimRight(q.Name, ".") 111 if fakePool.ShouldSkipped(host) { 112 return next(ctx, r) 113 } 114 115 switch q.Qtype { 116 case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS: 117 return handleMsgWithEmptyAnswer(r), nil 118 } 119 120 if q.Qtype != D.TypeA { 121 return next(ctx, r) 122 } 123 124 rr := &D.A{} 125 rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} 126 ip := fakePool.Lookup(host) 127 rr.A = ip 128 msg := r.Copy() 129 msg.Answer = []D.RR{rr} 130 131 ctx.SetType(context.DNSTypeFakeIP) 132 setMsgTTL(msg, 1) 133 msg.SetRcode(r, D.RcodeSuccess) 134 msg.Authoritative = true 135 msg.RecursionAvailable = true 136 137 return msg, nil 138 } 139 } 140 } 141 142 func withResolver(resolver *Resolver) handler { 143 return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { 144 ctx.SetType(context.DNSTypeRaw) 145 q := r.Question[0] 146 147 // return a empty AAAA msg when ipv6 disabled 148 if !resolver.ipv6 && q.Qtype == D.TypeAAAA { 149 return handleMsgWithEmptyAnswer(r), nil 150 } 151 152 msg, err := resolver.Exchange(r) 153 if err != nil { 154 log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err) 155 return msg, err 156 } 157 msg.SetRcode(r, msg.Rcode) 158 msg.Authoritative = true 159 160 return msg, nil 161 } 162 } 163 164 func compose(middlewares []middleware, endpoint handler) handler { 165 length := len(middlewares) 166 h := endpoint 167 for i := length - 1; i >= 0; i-- { 168 middleware := middlewares[i] 169 h = middleware(h) 170 } 171 172 return h 173 } 174 175 func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { 176 middlewares := []middleware{} 177 178 if resolver.hosts != nil { 179 middlewares = append(middlewares, withHosts(resolver.hosts)) 180 } 181 182 if mapper.mode == C.DNSFakeIP { 183 middlewares = append(middlewares, withFakeIP(mapper.fakePool)) 184 } 185 186 if mapper.mode != C.DNSNormal { 187 middlewares = append(middlewares, withMapping(mapper.mapping)) 188 } 189 190 return compose(middlewares, withResolver(resolver)) 191 }