github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/fake.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"net"
     8  	"net/netip"
     9  	"strings"
    10  	"sync"
    11  	"unsafe"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	"github.com/Asutorufa/yuhaiin/pkg/utils/cache"
    15  	"github.com/Asutorufa/yuhaiin/pkg/utils/lru"
    16  	"golang.org/x/net/dns/dnsmessage"
    17  )
    18  
    19  var _ netapi.Resolver = (*FakeDNS)(nil)
    20  
    21  type FakeDNS struct {
    22  	netapi.Resolver
    23  	ipv4 *FakeIPPool
    24  	ipv6 *FakeIPPool
    25  }
    26  
    27  func NewFakeDNS(
    28  	upStreamDo netapi.Resolver,
    29  	ipRange netip.Prefix,
    30  	ipv6Range netip.Prefix,
    31  	bbolt, bboltv6 *cache.Cache,
    32  ) *FakeDNS {
    33  	return &FakeDNS{upStreamDo, NewFakeIPPool(ipRange, bbolt), NewFakeIPPool(ipv6Range, bboltv6)}
    34  }
    35  
    36  func (f *FakeDNS) LookupIP(_ context.Context, domain string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) {
    37  	opt := &netapi.LookupIPOption{}
    38  	for _, optf := range opts {
    39  		optf(opt)
    40  	}
    41  
    42  	if opt.AAAA && !opt.A {
    43  		return []net.IP{f.ipv6.GetFakeIPForDomain(domain).AsSlice()}, nil
    44  	}
    45  
    46  	if opt.A && !opt.AAAA {
    47  		return []net.IP{f.ipv4.GetFakeIPForDomain(domain).AsSlice()}, nil
    48  	}
    49  
    50  	return []net.IP{f.ipv4.GetFakeIPForDomain(domain).AsSlice(), f.ipv6.GetFakeIPForDomain(domain).AsSlice()}, nil
    51  }
    52  
    53  func (f *FakeDNS) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) {
    54  	if req.Type != dnsmessage.TypeA && req.Type != dnsmessage.TypeAAAA && req.Type != dnsmessage.TypePTR {
    55  		return f.Resolver.Raw(ctx, req)
    56  	}
    57  
    58  	newAnswer := func(resource dnsmessage.ResourceBody) dnsmessage.Message {
    59  		msg := dnsmessage.Message{
    60  			Header: dnsmessage.Header{
    61  				ID:                 0,
    62  				Response:           true,
    63  				Authoritative:      false,
    64  				RecursionDesired:   false,
    65  				RCode:              dnsmessage.RCodeSuccess,
    66  				RecursionAvailable: false,
    67  			},
    68  			Questions: []dnsmessage.Question{
    69  				{
    70  					Name:  req.Name,
    71  					Type:  req.Type,
    72  					Class: dnsmessage.ClassINET,
    73  				},
    74  			},
    75  		}
    76  
    77  		answer := dnsmessage.Resource{
    78  			Header: dnsmessage.ResourceHeader{
    79  				Name:  req.Name,
    80  				Class: dnsmessage.ClassINET,
    81  				TTL:   600,
    82  				Type:  req.Type,
    83  			},
    84  			Body: resource,
    85  		}
    86  
    87  		msg.Answers = append(msg.Answers, answer)
    88  
    89  		return msg
    90  	}
    91  
    92  	if req.Type == dnsmessage.TypePTR {
    93  		domain, err := f.LookupPtr(req.Name.String())
    94  		if err != nil {
    95  			return f.Resolver.Raw(ctx, req)
    96  		}
    97  
    98  		msg := newAnswer(&dnsmessage.PTRResource{
    99  			PTR: dnsmessage.MustNewName(domain + "."),
   100  		})
   101  
   102  		return msg, nil
   103  	}
   104  	if req.Type == dnsmessage.TypeAAAA {
   105  		ip := f.ipv6.GetFakeIPForDomain(strings.TrimSuffix(req.Name.String(), "."))
   106  		return newAnswer(&dnsmessage.AAAAResource{AAAA: ip.As16()}), nil
   107  	}
   108  
   109  	if req.Type == dnsmessage.TypeA {
   110  		ip := f.ipv4.GetFakeIPForDomain(strings.TrimSuffix(req.Name.String(), "."))
   111  		return newAnswer(&dnsmessage.AResource{A: ip.As4()}), nil
   112  	}
   113  
   114  	return f.Resolver.Raw(ctx, req)
   115  }
   116  
   117  func (f *FakeDNS) GetDomainFromIP(ip netip.Addr) (string, bool) {
   118  	if ip.Unmap().Is6() {
   119  		return f.ipv6.GetDomainFromIP(ip)
   120  	} else {
   121  		return f.ipv4.GetDomainFromIP(ip)
   122  	}
   123  }
   124  
   125  var hex = map[byte]byte{
   126  	'0': 0,
   127  	'1': 1,
   128  	'2': 2,
   129  	'3': 3,
   130  	'4': 4,
   131  	'5': 5,
   132  	'6': 6,
   133  	'7': 7,
   134  	'8': 8,
   135  	'9': 9,
   136  	'A': 10,
   137  	'a': 10,
   138  	'b': 11,
   139  	'B': 11,
   140  	'C': 12,
   141  	'c': 12,
   142  	'D': 13,
   143  	'd': 13,
   144  	'e': 14,
   145  	'E': 14,
   146  	'f': 15,
   147  	'F': 15,
   148  }
   149  
   150  func RetrieveIPFromPtr(name string) (net.IP, error) {
   151  	i := strings.Index(name, "ip6.arpa.")
   152  	if i != -1 && len(name[:i]) == 64 {
   153  		var ip [16]byte
   154  		for i := range ip {
   155  			ip[i] = hex[name[62-i*4]]*16 + hex[name[62-i*4-2]]
   156  		}
   157  		return net.IP(ip[:]), nil
   158  	}
   159  
   160  	if i = strings.Index(name, "in-addr.arpa."); i == -1 {
   161  		return nil, fmt.Errorf("ptr format failed: %s", name)
   162  	}
   163  
   164  	var ip [4]byte
   165  	var dotCount uint8
   166  
   167  	for _, v := range name[:i] {
   168  		if dotCount > 3 {
   169  			break
   170  		}
   171  
   172  		if v == '.' {
   173  			dotCount++
   174  		} else {
   175  			ip[3-dotCount] = ip[3-dotCount]*10 + hex[byte(v)]
   176  		}
   177  	}
   178  
   179  	return net.IP(ip[:]), nil
   180  }
   181  
   182  func (f *FakeDNS) LookupPtr(name string) (string, error) {
   183  	ip, err := RetrieveIPFromPtr(name)
   184  	if err != nil {
   185  		return "", err
   186  	}
   187  
   188  	ipAddr, ok := netip.AddrFromSlice(ip)
   189  	if !ok {
   190  		return "", fmt.Errorf("parse netip.Addr from bytes failed")
   191  	}
   192  
   193  	r, ok := f.ipv4.GetDomainFromIP(ipAddr.Unmap())
   194  	if ok {
   195  		return r, nil
   196  	}
   197  
   198  	r, ok = f.ipv6.GetDomainFromIP(ipAddr.Unmap())
   199  	if ok {
   200  		return r, nil
   201  	}
   202  
   203  	return r, fmt.Errorf("ptr not found")
   204  }
   205  
   206  func (f *FakeDNS) Close() error { return nil }
   207  
   208  type FakeIPPool struct {
   209  	prefix     netip.Prefix
   210  	current    netip.Addr
   211  	domainToIP *fakeLru
   212  
   213  	mu sync.Mutex
   214  }
   215  
   216  func NewFakeIPPool(prefix netip.Prefix, bbolt *cache.Cache) *FakeIPPool {
   217  	if bbolt == nil {
   218  		bbolt = cache.NewCache(nil, "")
   219  	}
   220  
   221  	prefix = prefix.Masked()
   222  
   223  	lenSize := 32
   224  	if prefix.Addr().Is6() {
   225  		lenSize = 128
   226  	}
   227  
   228  	var lruSize uint
   229  	if prefix.Bits() == lenSize {
   230  		lruSize = 0
   231  	} else {
   232  		lruSize = uint(math.Pow(2, float64(lenSize-prefix.Bits())) - 1)
   233  	}
   234  
   235  	return &FakeIPPool{
   236  		prefix:     prefix,
   237  		current:    prefix.Addr().Prev(),
   238  		domainToIP: newFakeLru(lruSize, bbolt),
   239  	}
   240  }
   241  
   242  func (n *FakeIPPool) GetFakeIPForDomain(s string) netip.Addr {
   243  	if z, ok := n.domainToIP.Load(s); ok {
   244  		return z
   245  	}
   246  
   247  	n.mu.Lock()
   248  	defer n.mu.Unlock()
   249  
   250  	if z, ok := n.domainToIP.Load(s); ok {
   251  		return z
   252  	}
   253  
   254  	if v, ok := n.domainToIP.LastPopValue(); ok {
   255  		n.domainToIP.Add(s, v)
   256  		return v
   257  	}
   258  
   259  	for {
   260  		addr := n.current.Next()
   261  
   262  		if !n.prefix.Contains(addr) {
   263  			n.current = n.prefix.Addr().Prev()
   264  			continue
   265  		}
   266  
   267  		n.current = addr
   268  
   269  		if !n.domainToIP.ValueExist(addr) {
   270  			n.domainToIP.Add(s, addr)
   271  			return addr
   272  		}
   273  	}
   274  }
   275  
   276  func (n *FakeIPPool) GetDomainFromIP(ip netip.Addr) (string, bool) {
   277  	if !n.prefix.Contains(ip) {
   278  		return "", false
   279  	}
   280  
   281  	return n.domainToIP.ReverseLoad(ip.Unmap())
   282  }
   283  
   284  func (n *FakeIPPool) LRU() *lru.LRU[string, netip.Addr] { return n.domainToIP.LRU }
   285  
   286  type fakeLru struct {
   287  	LRU   *lru.LRU[string, netip.Addr]
   288  	bbolt *cache.Cache
   289  
   290  	Size uint
   291  }
   292  
   293  func newFakeLru(size uint, bbolt *cache.Cache) *fakeLru {
   294  	z := &fakeLru{Size: size, bbolt: bbolt}
   295  
   296  	if size > 0 {
   297  		z.LRU = lru.New(
   298  			lru.WithCapacity[string, netip.Addr](size),
   299  			lru.WithOnRemove(func(s string, v netip.Addr) { bbolt.Delete([]byte(s), v.AsSlice()) }),
   300  		)
   301  	}
   302  
   303  	return z
   304  }
   305  
   306  func (f *fakeLru) Load(host string) (netip.Addr, bool) {
   307  	if f.Size <= 0 {
   308  		return netip.Addr{}, false
   309  	}
   310  
   311  	z, ok := f.LRU.Load(host)
   312  	if ok {
   313  		return z, ok
   314  	}
   315  
   316  	if ip, ok := netip.AddrFromSlice(f.bbolt.Get(unsafe.Slice(unsafe.StringData(host), len(host)))); ok {
   317  		ip = ip.Unmap()
   318  		f.LRU.Add(host, ip)
   319  		return ip, true
   320  	}
   321  
   322  	return netip.Addr{}, false
   323  }
   324  
   325  func (f *fakeLru) Add(host string, ip netip.Addr) {
   326  	if f.Size <= 0 {
   327  		return
   328  	}
   329  	f.LRU.Add(host, ip)
   330  
   331  	if f.bbolt != nil {
   332  		host, ip := []byte(host), ip.AsSlice()
   333  		f.bbolt.Put(host, ip)
   334  		f.bbolt.Put(ip, host)
   335  	}
   336  }
   337  
   338  func (f *fakeLru) ValueExist(ip netip.Addr) bool {
   339  	if f.Size <= 0 {
   340  		return false
   341  	}
   342  
   343  	if f.LRU.ValueExist(ip) {
   344  		return true
   345  	}
   346  
   347  	if host := f.bbolt.Get(ip.AsSlice()); host != nil {
   348  		f.LRU.Add(string(host), ip)
   349  		return true
   350  	}
   351  
   352  	return false
   353  }
   354  
   355  func (f *fakeLru) ReverseLoad(ip netip.Addr) (string, bool) {
   356  	if f.Size <= 0 {
   357  		return "", false
   358  	}
   359  
   360  	host, ok := f.LRU.ReverseLoad(ip)
   361  	if ok {
   362  		return host, ok
   363  	}
   364  
   365  	if host = string(f.bbolt.Get(ip.AsSlice())); host != "" {
   366  		f.LRU.Add(host, ip)
   367  		return host, true
   368  	}
   369  
   370  	return "", false
   371  }
   372  
   373  func (f *fakeLru) LastPopValue() (netip.Addr, bool) {
   374  	if f.Size <= 0 {
   375  		return netip.Addr{}, false
   376  	}
   377  	return f.LRU.LastPopValue()
   378  }