github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/fake.go (about) 1 package dns 2 3 import ( 4 "context" 5 "fmt" 6 "math" 7 "net" 8 "net/netip" 9 "strings" 10 "sync" 11 "unsafe" 12 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 "github.com/Asutorufa/yuhaiin/pkg/utils/cache" 15 "github.com/Asutorufa/yuhaiin/pkg/utils/lru" 16 "golang.org/x/net/dns/dnsmessage" 17 ) 18 19 var _ netapi.Resolver = (*FakeDNS)(nil) 20 21 type FakeDNS struct { 22 netapi.Resolver 23 ipv4 *FakeIPPool 24 ipv6 *FakeIPPool 25 } 26 27 func NewFakeDNS( 28 upStreamDo netapi.Resolver, 29 ipRange netip.Prefix, 30 ipv6Range netip.Prefix, 31 bbolt, bboltv6 *cache.Cache, 32 ) *FakeDNS { 33 return &FakeDNS{upStreamDo, NewFakeIPPool(ipRange, bbolt), NewFakeIPPool(ipv6Range, bboltv6)} 34 } 35 36 func (f *FakeDNS) LookupIP(_ context.Context, domain string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) { 37 opt := &netapi.LookupIPOption{} 38 for _, optf := range opts { 39 optf(opt) 40 } 41 42 if opt.AAAA && !opt.A { 43 return []net.IP{f.ipv6.GetFakeIPForDomain(domain).AsSlice()}, nil 44 } 45 46 if opt.A && !opt.AAAA { 47 return []net.IP{f.ipv4.GetFakeIPForDomain(domain).AsSlice()}, nil 48 } 49 50 return []net.IP{f.ipv4.GetFakeIPForDomain(domain).AsSlice(), f.ipv6.GetFakeIPForDomain(domain).AsSlice()}, nil 51 } 52 53 func (f *FakeDNS) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) { 54 if req.Type != dnsmessage.TypeA && req.Type != dnsmessage.TypeAAAA && req.Type != dnsmessage.TypePTR { 55 return f.Resolver.Raw(ctx, req) 56 } 57 58 newAnswer := func(resource dnsmessage.ResourceBody) dnsmessage.Message { 59 msg := dnsmessage.Message{ 60 Header: dnsmessage.Header{ 61 ID: 0, 62 Response: true, 63 Authoritative: false, 64 RecursionDesired: false, 65 RCode: dnsmessage.RCodeSuccess, 66 RecursionAvailable: false, 67 }, 68 Questions: []dnsmessage.Question{ 69 { 70 Name: req.Name, 71 Type: req.Type, 72 Class: dnsmessage.ClassINET, 73 }, 74 }, 75 } 76 77 answer := dnsmessage.Resource{ 78 Header: dnsmessage.ResourceHeader{ 79 Name: req.Name, 80 Class: dnsmessage.ClassINET, 81 TTL: 600, 82 Type: req.Type, 83 }, 84 Body: resource, 85 } 86 87 msg.Answers = append(msg.Answers, answer) 88 89 return msg 90 } 91 92 if req.Type == dnsmessage.TypePTR { 93 domain, err := f.LookupPtr(req.Name.String()) 94 if err != nil { 95 return f.Resolver.Raw(ctx, req) 96 } 97 98 msg := newAnswer(&dnsmessage.PTRResource{ 99 PTR: dnsmessage.MustNewName(domain + "."), 100 }) 101 102 return msg, nil 103 } 104 if req.Type == dnsmessage.TypeAAAA { 105 ip := f.ipv6.GetFakeIPForDomain(strings.TrimSuffix(req.Name.String(), ".")) 106 return newAnswer(&dnsmessage.AAAAResource{AAAA: ip.As16()}), nil 107 } 108 109 if req.Type == dnsmessage.TypeA { 110 ip := f.ipv4.GetFakeIPForDomain(strings.TrimSuffix(req.Name.String(), ".")) 111 return newAnswer(&dnsmessage.AResource{A: ip.As4()}), nil 112 } 113 114 return f.Resolver.Raw(ctx, req) 115 } 116 117 func (f *FakeDNS) GetDomainFromIP(ip netip.Addr) (string, bool) { 118 if ip.Unmap().Is6() { 119 return f.ipv6.GetDomainFromIP(ip) 120 } else { 121 return f.ipv4.GetDomainFromIP(ip) 122 } 123 } 124 125 var hex = map[byte]byte{ 126 '0': 0, 127 '1': 1, 128 '2': 2, 129 '3': 3, 130 '4': 4, 131 '5': 5, 132 '6': 6, 133 '7': 7, 134 '8': 8, 135 '9': 9, 136 'A': 10, 137 'a': 10, 138 'b': 11, 139 'B': 11, 140 'C': 12, 141 'c': 12, 142 'D': 13, 143 'd': 13, 144 'e': 14, 145 'E': 14, 146 'f': 15, 147 'F': 15, 148 } 149 150 func RetrieveIPFromPtr(name string) (net.IP, error) { 151 i := strings.Index(name, "ip6.arpa.") 152 if i != -1 && len(name[:i]) == 64 { 153 var ip [16]byte 154 for i := range ip { 155 ip[i] = hex[name[62-i*4]]*16 + hex[name[62-i*4-2]] 156 } 157 return net.IP(ip[:]), nil 158 } 159 160 if i = strings.Index(name, "in-addr.arpa."); i == -1 { 161 return nil, fmt.Errorf("ptr format failed: %s", name) 162 } 163 164 var ip [4]byte 165 var dotCount uint8 166 167 for _, v := range name[:i] { 168 if dotCount > 3 { 169 break 170 } 171 172 if v == '.' { 173 dotCount++ 174 } else { 175 ip[3-dotCount] = ip[3-dotCount]*10 + hex[byte(v)] 176 } 177 } 178 179 return net.IP(ip[:]), nil 180 } 181 182 func (f *FakeDNS) LookupPtr(name string) (string, error) { 183 ip, err := RetrieveIPFromPtr(name) 184 if err != nil { 185 return "", err 186 } 187 188 ipAddr, ok := netip.AddrFromSlice(ip) 189 if !ok { 190 return "", fmt.Errorf("parse netip.Addr from bytes failed") 191 } 192 193 r, ok := f.ipv4.GetDomainFromIP(ipAddr.Unmap()) 194 if ok { 195 return r, nil 196 } 197 198 r, ok = f.ipv6.GetDomainFromIP(ipAddr.Unmap()) 199 if ok { 200 return r, nil 201 } 202 203 return r, fmt.Errorf("ptr not found") 204 } 205 206 func (f *FakeDNS) Close() error { return nil } 207 208 type FakeIPPool struct { 209 prefix netip.Prefix 210 current netip.Addr 211 domainToIP *fakeLru 212 213 mu sync.Mutex 214 } 215 216 func NewFakeIPPool(prefix netip.Prefix, bbolt *cache.Cache) *FakeIPPool { 217 if bbolt == nil { 218 bbolt = cache.NewCache(nil, "") 219 } 220 221 prefix = prefix.Masked() 222 223 lenSize := 32 224 if prefix.Addr().Is6() { 225 lenSize = 128 226 } 227 228 var lruSize uint 229 if prefix.Bits() == lenSize { 230 lruSize = 0 231 } else { 232 lruSize = uint(math.Pow(2, float64(lenSize-prefix.Bits())) - 1) 233 } 234 235 return &FakeIPPool{ 236 prefix: prefix, 237 current: prefix.Addr().Prev(), 238 domainToIP: newFakeLru(lruSize, bbolt), 239 } 240 } 241 242 func (n *FakeIPPool) GetFakeIPForDomain(s string) netip.Addr { 243 if z, ok := n.domainToIP.Load(s); ok { 244 return z 245 } 246 247 n.mu.Lock() 248 defer n.mu.Unlock() 249 250 if z, ok := n.domainToIP.Load(s); ok { 251 return z 252 } 253 254 if v, ok := n.domainToIP.LastPopValue(); ok { 255 n.domainToIP.Add(s, v) 256 return v 257 } 258 259 for { 260 addr := n.current.Next() 261 262 if !n.prefix.Contains(addr) { 263 n.current = n.prefix.Addr().Prev() 264 continue 265 } 266 267 n.current = addr 268 269 if !n.domainToIP.ValueExist(addr) { 270 n.domainToIP.Add(s, addr) 271 return addr 272 } 273 } 274 } 275 276 func (n *FakeIPPool) GetDomainFromIP(ip netip.Addr) (string, bool) { 277 if !n.prefix.Contains(ip) { 278 return "", false 279 } 280 281 return n.domainToIP.ReverseLoad(ip.Unmap()) 282 } 283 284 func (n *FakeIPPool) LRU() *lru.LRU[string, netip.Addr] { return n.domainToIP.LRU } 285 286 type fakeLru struct { 287 LRU *lru.LRU[string, netip.Addr] 288 bbolt *cache.Cache 289 290 Size uint 291 } 292 293 func newFakeLru(size uint, bbolt *cache.Cache) *fakeLru { 294 z := &fakeLru{Size: size, bbolt: bbolt} 295 296 if size > 0 { 297 z.LRU = lru.New( 298 lru.WithCapacity[string, netip.Addr](size), 299 lru.WithOnRemove(func(s string, v netip.Addr) { bbolt.Delete([]byte(s), v.AsSlice()) }), 300 ) 301 } 302 303 return z 304 } 305 306 func (f *fakeLru) Load(host string) (netip.Addr, bool) { 307 if f.Size <= 0 { 308 return netip.Addr{}, false 309 } 310 311 z, ok := f.LRU.Load(host) 312 if ok { 313 return z, ok 314 } 315 316 if ip, ok := netip.AddrFromSlice(f.bbolt.Get(unsafe.Slice(unsafe.StringData(host), len(host)))); ok { 317 ip = ip.Unmap() 318 f.LRU.Add(host, ip) 319 return ip, true 320 } 321 322 return netip.Addr{}, false 323 } 324 325 func (f *fakeLru) Add(host string, ip netip.Addr) { 326 if f.Size <= 0 { 327 return 328 } 329 f.LRU.Add(host, ip) 330 331 if f.bbolt != nil { 332 host, ip := []byte(host), ip.AsSlice() 333 f.bbolt.Put(host, ip) 334 f.bbolt.Put(ip, host) 335 } 336 } 337 338 func (f *fakeLru) ValueExist(ip netip.Addr) bool { 339 if f.Size <= 0 { 340 return false 341 } 342 343 if f.LRU.ValueExist(ip) { 344 return true 345 } 346 347 if host := f.bbolt.Get(ip.AsSlice()); host != nil { 348 f.LRU.Add(string(host), ip) 349 return true 350 } 351 352 return false 353 } 354 355 func (f *fakeLru) ReverseLoad(ip netip.Addr) (string, bool) { 356 if f.Size <= 0 { 357 return "", false 358 } 359 360 host, ok := f.LRU.ReverseLoad(ip) 361 if ok { 362 return host, ok 363 } 364 365 if host = string(f.bbolt.Get(ip.AsSlice())); host != "" { 366 f.LRU.Add(host, ip) 367 return host, true 368 } 369 370 return "", false 371 } 372 373 func (f *fakeLru) LastPopValue() (netip.Addr, bool) { 374 if f.Size <= 0 { 375 return netip.Addr{}, false 376 } 377 return f.LRU.LastPopValue() 378 }