github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/resolver/fakeip.go (about)

     1  package resolver
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  	"slices"
     9  	"strings"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/log"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/dns"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/trie/domain"
    15  	pc "github.com/Asutorufa/yuhaiin/pkg/protos/config"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/cache"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/yerror"
    18  	"golang.org/x/net/dns/dnsmessage"
    19  )
    20  
    21  type Fakedns struct {
    22  	enabled  bool
    23  	fake     *dns.FakeDNS
    24  	dialer   netapi.Proxy
    25  	upstream netapi.Resolver
    26  	cache    *cache.Cache
    27  	cachev6  *cache.Cache
    28  
    29  	whitelistSlice []string
    30  	whitelist      *domain.Fqdn[struct{}]
    31  }
    32  
    33  func NewFakeDNS(dialer netapi.Proxy, upstream netapi.Resolver, bbolt, bboltv6 *cache.Cache) *Fakedns {
    34  	return &Fakedns{
    35  		fake: dns.NewFakeDNS(upstream,
    36  			yerror.Ignore(netip.ParsePrefix("10.2.0.1/24")),
    37  			yerror.Ignore(netip.ParsePrefix("fc00::/64")),
    38  			bbolt, bboltv6),
    39  		dialer:    dialer,
    40  		upstream:  upstream,
    41  		cache:     bbolt,
    42  		cachev6:   bboltv6,
    43  		whitelist: domain.NewDomainMapper[struct{}](),
    44  	}
    45  }
    46  
    47  func (f *Fakedns) Update(c *pc.Setting) {
    48  	f.enabled = c.Dns.Fakedns
    49  
    50  	if !slices.Equal(c.Dns.FakednsWhitelist, f.whitelistSlice) {
    51  		log.Info("update fakedns whitelist", "old", f.whitelistSlice, "new", c.Dns.FakednsWhitelist)
    52  
    53  		d := domain.NewDomainMapper[struct{}]()
    54  
    55  		for _, v := range c.Dns.FakednsWhitelist {
    56  			d.Insert(v, struct{}{})
    57  		}
    58  
    59  		f.whitelist = d
    60  		f.whitelistSlice = c.Dns.FakednsWhitelist
    61  	}
    62  
    63  	ipRange, er4 := netip.ParsePrefix(c.Dns.FakednsIpRange)
    64  	if er4 != nil {
    65  		log.Error("parse fakedns ip range failed", "err", er4)
    66  		ipRange, _ = netip.ParsePrefix("10.2.0.1/24")
    67  	}
    68  
    69  	ipv6Range, er6 := netip.ParsePrefix(c.Dns.FakednsIpv6Range)
    70  	if er6 != nil {
    71  		log.Error("parse fakedns PreferIPv6 range failed", "err", er6)
    72  		ipv6Range, _ = netip.ParsePrefix("fc00::/64")
    73  	}
    74  
    75  	if er4 != nil && er6 != nil {
    76  		return
    77  	}
    78  
    79  	f.fake = dns.NewFakeDNS(f.upstream, ipRange, ipv6Range, f.cache, f.cachev6)
    80  }
    81  
    82  func (f *Fakedns) resolver(ctx context.Context, domain string) netapi.Resolver {
    83  	if f.enabled || ctx.Value(netapi.ForceFakeIP{}) == true {
    84  		if _, ok := f.whitelist.SearchString(strings.TrimSuffix(domain, ".")); ok {
    85  			return f.upstream
    86  		}
    87  
    88  		return f.fake
    89  	}
    90  
    91  	return f.upstream
    92  }
    93  
    94  func (f *Fakedns) LookupIP(ctx context.Context, domain string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) {
    95  	return f.resolver(ctx, domain).LookupIP(ctx, domain, opts...)
    96  }
    97  
    98  func (f *Fakedns) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) {
    99  	return f.resolver(ctx, req.Name.String()).Raw(ctx, req)
   100  }
   101  
   102  func (f *Fakedns) Close() error { return f.upstream.Close() }
   103  
   104  func (f *Fakedns) Dispatch(ctx context.Context, addr netapi.Address) (netapi.Address, error) {
   105  	return f.dialer.Dispatch(ctx, f.dispatchAddr(ctx, addr))
   106  }
   107  
   108  func (f *Fakedns) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
   109  	c, err := f.dialer.Conn(ctx, f.dispatchAddr(ctx, addr))
   110  	if err != nil {
   111  		return nil, fmt.Errorf("connect tcp to %s failed: %w", addr, err)
   112  	}
   113  
   114  	return c, nil
   115  }
   116  
   117  func (f *Fakedns) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
   118  	c, err := f.dialer.PacketConn(ctx, f.dispatchAddr(ctx, addr))
   119  	if err != nil {
   120  		return nil, fmt.Errorf("connect udp to %s failed: %w", addr, err)
   121  	}
   122  
   123  	c = &dispatchPacketConn{c, f.dispatchAddr}
   124  
   125  	return c, nil
   126  }
   127  
   128  func (f *Fakedns) dispatchAddr(ctx context.Context, addr netapi.Address) netapi.Address {
   129  	if addr.Type() == netapi.IP {
   130  		t, ok := f.fake.GetDomainFromIP(addr.AddrPort(ctx).V.Addr())
   131  		if ok {
   132  			r := addr.OverrideHostname(t)
   133  			netapi.StoreFromContext(ctx).
   134  				Add(netapi.FakeIPKey{}, addr).
   135  				Add(netapi.CurrentKey{}, r)
   136  			return r
   137  		}
   138  	}
   139  	return addr
   140  }
   141  
   142  type dispatchPacketConn struct {
   143  	net.PacketConn
   144  	dispatch func(context.Context, netapi.Address) netapi.Address
   145  }
   146  
   147  func (f *dispatchPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   148  	z, err := netapi.ParseSysAddr(addr)
   149  	if err != nil {
   150  		return 0, fmt.Errorf("parse addr failed: %w", err)
   151  	}
   152  
   153  	return f.PacketConn.WriteTo(b, f.dispatch(context.TODO(), z))
   154  }