github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/resolver/resolver.go (about)

     1  package resolver
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/components/shunt"
    10  	"github.com/Asutorufa/yuhaiin/pkg/log"
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/dns"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    13  	pc "github.com/Asutorufa/yuhaiin/pkg/protos/config"
    14  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/bypass"
    15  	pd "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    17  	"golang.org/x/net/dns/dnsmessage"
    18  	"google.golang.org/protobuf/proto"
    19  )
    20  
    21  type Entry struct {
    22  	Resolver netapi.Resolver
    23  	Config   *pd.Dns
    24  }
    25  
    26  type Resolver struct {
    27  	ipv6            bool
    28  	dialer          netapi.Proxy
    29  	bootstrapConfig *pd.Dns
    30  	store           syncmap.SyncMap[string, *Entry]
    31  }
    32  
    33  func NewResolver(dialer netapi.Proxy) *Resolver {
    34  	return &Resolver{dialer: dialer}
    35  }
    36  
    37  var errorResolver = netapi.ErrorResolver(func(domain string) error {
    38  	return fmt.Errorf("%w: %s", netapi.ErrBlocked, domain)
    39  })
    40  var blockStr = bypass.Mode_block.String()
    41  
    42  func (r *Resolver) Get(str string) netapi.Resolver {
    43  	if str != "" {
    44  		if str == blockStr {
    45  			return errorResolver
    46  		}
    47  		z, ok := r.store.Load(str)
    48  		if ok {
    49  			return z.Resolver
    50  		}
    51  	}
    52  
    53  	z, ok := r.store.Load(bypass.Mode_proxy.String())
    54  	if ok {
    55  		return z.Resolver
    56  	}
    57  
    58  	return netapi.Bootstrap
    59  }
    60  
    61  func (r *Resolver) Close() error {
    62  	r.store.Range(func(k string, v *Entry) bool {
    63  		v.Resolver.Close()
    64  		return true
    65  	})
    66  
    67  	r.store = syncmap.SyncMap[string, *Entry]{}
    68  
    69  	return nil
    70  }
    71  
    72  func (r *Resolver) GetIPv6() bool {
    73  	return r.ipv6
    74  }
    75  
    76  func (r *Resolver) Update(c *pc.Setting) {
    77  	c.Dns.Resolver = map[string]*pd.Dns{
    78  		bypass.Mode_direct.String(): c.Dns.Local,
    79  		bypass.Mode_proxy.String():  c.Dns.Remote,
    80  	}
    81  
    82  	r.ipv6 = c.GetIpv6()
    83  
    84  	if !proto.Equal(r.bootstrapConfig, c.Dns.Bootstrap) {
    85  		dialer := &dialer{
    86  			Proxy: r.dialer,
    87  			addr: func(ctx context.Context, addr netapi.Address) {
    88  				netapi.StoreFromContext(ctx).Add("Component", "Resolver BOOTSTRAP")
    89  				netapi.StoreFromContext(ctx).Add(shunt.ForceModeKey{}, bypass.Mode_direct)
    90  				addr.SetResolver(netapi.InternetResolver)
    91  				addr.SetSrc(netapi.AddressSrcDNS)
    92  			},
    93  		}
    94  		z, err := newDNS("BOOTSTRAP", c.Dns.Bootstrap, dialer, r)
    95  		if err != nil {
    96  			log.Error("get bootstrap dns failed", "err", err)
    97  		} else {
    98  			old := netapi.Bootstrap
    99  			netapi.Bootstrap = z
   100  			old.Close()
   101  		}
   102  	}
   103  
   104  	for k, v := range c.Dns.Resolver {
   105  		entry, ok := r.store.Load(k)
   106  		if ok && proto.Equal(entry.Config, v) {
   107  			continue
   108  		}
   109  
   110  		if entry != nil {
   111  			if err := entry.Resolver.Close(); err != nil {
   112  				log.Error("close dns resolver failed", "key", k, "err", err)
   113  			}
   114  		}
   115  
   116  		r.store.Delete(k)
   117  
   118  		dialer := &dialer{
   119  			Proxy: r.dialer,
   120  			addr: func(ctx context.Context, addr netapi.Address) {
   121  				netapi.StoreFromContext(ctx).Add("Component", "Resolver "+k)
   122  				// force to use bootstrap dns, otherwise will dns query cycle
   123  				addr.SetResolver(netapi.Bootstrap)
   124  				addr.SetSrc(netapi.AddressSrcDNS)
   125  			},
   126  		}
   127  
   128  		z, err := newDNS(k, v, dialer, r)
   129  		if err != nil {
   130  			log.Error("get local dns failed", "err", err)
   131  		} else {
   132  			r.store.Store(k, &Entry{
   133  				Resolver: z,
   134  				Config:   v,
   135  			})
   136  		}
   137  	}
   138  
   139  	r.store.Range(func(key string, value *Entry) bool {
   140  		_, ok := c.Dns.Resolver[key]
   141  		if !ok {
   142  			if err := value.Resolver.Close(); err != nil {
   143  				log.Error("close dns resolver failed", "key", key, "err", err)
   144  			}
   145  			r.store.Delete(key)
   146  		}
   147  		return true
   148  	})
   149  }
   150  
   151  type dnsWrap struct {
   152  	name     string
   153  	dns      netapi.Resolver
   154  	resolver *Resolver
   155  }
   156  
   157  func wrap(name string, dns netapi.Resolver, v6 *Resolver) *dnsWrap {
   158  	return &dnsWrap{name: name, dns: dns, resolver: v6}
   159  }
   160  
   161  func (d *dnsWrap) LookupIP(ctx context.Context, host string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) {
   162  	opt := func(opt *netapi.LookupIPOption) {
   163  		if d.resolver.GetIPv6() {
   164  			opt.AAAA = true
   165  		}
   166  
   167  		for _, o := range opts {
   168  			o(opt)
   169  		}
   170  	}
   171  
   172  	ips, err := d.dns.LookupIP(ctx, host, opt)
   173  	if err != nil {
   174  		return nil, fmt.Errorf("%s lookup failed: %w", d.name, err)
   175  	}
   176  
   177  	return ips, nil
   178  }
   179  
   180  func (d *dnsWrap) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) {
   181  	msg, err := d.dns.Raw(ctx, req)
   182  	if err != nil {
   183  		return dnsmessage.Message{}, fmt.Errorf("%s do raw dns request failed: %w", d.name, err)
   184  	}
   185  
   186  	return msg, nil
   187  }
   188  
   189  func (d *dnsWrap) Close() error {
   190  	if d.dns != nil {
   191  		return d.dns.Close()
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  func newDNS(name string, dc *pd.Dns, dialer netapi.Proxy, resovler *Resolver) (netapi.Resolver, error) {
   198  	subnet, err := netip.ParsePrefix(dc.Subnet)
   199  	if err != nil {
   200  		p, err := netip.ParseAddr(dc.Subnet)
   201  		if err == nil {
   202  			subnet = netip.PrefixFrom(p, p.BitLen())
   203  		}
   204  	}
   205  	r, err := dns.New(
   206  		dns.Config{
   207  			Type:       dc.Type,
   208  			Name:       name,
   209  			Host:       dc.Host,
   210  			Servername: dc.TlsServername,
   211  			Subnet:     subnet,
   212  			Dialer:     dialer,
   213  		})
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  
   218  	return wrap(name, r, resovler), nil
   219  }
   220  
   221  type dialer struct {
   222  	netapi.Proxy
   223  	addr func(ctx context.Context, addr netapi.Address)
   224  }
   225  
   226  func (d *dialer) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
   227  	ctx = netapi.NewStore(ctx)
   228  	d.addr(ctx, addr)
   229  	return d.Proxy.Conn(ctx, addr)
   230  }
   231  
   232  func (d *dialer) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
   233  	ctx = netapi.NewStore(ctx)
   234  	d.addr(ctx, addr)
   235  	return d.Proxy.PacketConn(ctx, addr)
   236  }