github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/shunt/shunt.go (about)

     1  package shunt
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"os"
     8  	"slices"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/Asutorufa/yuhaiin/pkg/log"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/trie"
    15  	pc "github.com/Asutorufa/yuhaiin/pkg/protos/config"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/bypass"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/convert"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    19  	"golang.org/x/exp/maps"
    20  	"golang.org/x/net/dns/dnsmessage"
    21  	"google.golang.org/protobuf/proto"
    22  )
    23  
    24  type modeMarkKey struct{}
    25  
    26  func (modeMarkKey) String() string { return "MODE" }
    27  
    28  type DOMAIN_MARK_KEY struct{}
    29  
    30  type IP_MARK_KEY struct{}
    31  
    32  func (IP_MARK_KEY) String() string { return "IP" }
    33  
    34  type ForceModeKey struct{}
    35  
    36  type Shunt struct {
    37  	resolveDomain bool
    38  	modifiedTime  int64
    39  
    40  	config       *bypass.BypassConfig
    41  	mapper       *trie.Trie[bypass.ModeEnum]
    42  	customMapper *trie.Trie[bypass.ModeEnum]
    43  
    44  	processMapper syncmap.SyncMap[string, bypass.ModeEnum]
    45  	ProcessDumper netapi.ProcessDumper
    46  
    47  	mu sync.RWMutex
    48  
    49  	r Resolver
    50  	d Dialer
    51  
    52  	tags map[string]struct{}
    53  }
    54  
    55  type Resolver interface {
    56  	Get(str string) netapi.Resolver
    57  }
    58  type Dialer interface {
    59  	Get(ctx context.Context, network string, str string, tag string) (netapi.Proxy, error)
    60  }
    61  
    62  func NewShunt(d Dialer, r Resolver, ProcessDumper netapi.ProcessDumper) *Shunt {
    63  	return &Shunt{
    64  		mapper:       trie.NewTrie[bypass.ModeEnum](),
    65  		customMapper: trie.NewTrie[bypass.ModeEnum](),
    66  		config: &bypass.BypassConfig{
    67  			Tcp: bypass.Mode_bypass,
    68  			Udp: bypass.Mode_bypass,
    69  		},
    70  		r:             r,
    71  		d:             d,
    72  		ProcessDumper: ProcessDumper,
    73  		tags:          make(map[string]struct{}),
    74  	}
    75  }
    76  
    77  func (s *Shunt) Update(c *pc.Setting) {
    78  	s.mu.Lock()
    79  	defer s.mu.Unlock()
    80  
    81  	s.resolveDomain = c.Dns.ResolveRemoteDomain
    82  
    83  	if !slices.EqualFunc(
    84  		s.config.CustomRuleV3,
    85  		c.Bypass.CustomRuleV3,
    86  		func(mc1, mc2 *bypass.ModeConfig) bool { return proto.Equal(mc1, mc2) },
    87  	) {
    88  		s.customMapper.Clear() //nolint:errcheck
    89  		s.processMapper = syncmap.SyncMap[string, bypass.ModeEnum]{}
    90  
    91  		for _, v := range c.Bypass.CustomRuleV3 {
    92  			mark := v.ToModeEnum()
    93  
    94  			if mark.GetTag() != "" {
    95  				s.tags[mark.GetTag()] = struct{}{}
    96  			}
    97  
    98  			for _, hostname := range v.Hostname {
    99  				if strings.HasPrefix(hostname, "process:") {
   100  					s.processMapper.Store(hostname[8:], mark)
   101  				} else {
   102  					s.customMapper.Insert(hostname, mark)
   103  				}
   104  			}
   105  		}
   106  	}
   107  
   108  	modifiedTime := s.modifiedTime
   109  	if stat, err := os.Stat(c.Bypass.BypassFile); err == nil {
   110  		modifiedTime = stat.ModTime().Unix()
   111  	}
   112  
   113  	if s.config.BypassFile != c.Bypass.BypassFile || s.modifiedTime != modifiedTime {
   114  		s.mapper.Clear() //nolint:errcheck
   115  		s.tags = make(map[string]struct{})
   116  		s.modifiedTime = modifiedTime
   117  		rangeRule(c.Bypass.BypassFile, func(s1 string, s2 bypass.ModeEnum) {
   118  			if strings.HasPrefix(s1, "process:") {
   119  				s.processMapper.Store(s1[8:], s2.Mode())
   120  			} else {
   121  				s.mapper.Insert(s1, s2)
   122  			}
   123  
   124  			if s2.GetTag() != "" {
   125  				s.tags[s2.GetTag()] = struct{}{}
   126  			}
   127  		})
   128  	}
   129  
   130  	s.config = c.Bypass
   131  }
   132  
   133  func (s *Shunt) Tags() []string { return maps.Keys(s.tags) }
   134  
   135  func (s *Shunt) Conn(ctx context.Context, host netapi.Address) (net.Conn, error) {
   136  	mode, host := s.dispatch(ctx, s.config.Tcp, host)
   137  
   138  	p, err := s.d.Get(ctx, "tcp", mode.Mode().String(), mode.GetTag())
   139  	if err != nil {
   140  		return nil, fmt.Errorf("dial %s failed: %w", host, err)
   141  	}
   142  
   143  	conn, err := p.Conn(ctx, host)
   144  	if err != nil {
   145  		return nil, fmt.Errorf("dial %s failed: %w", host, err)
   146  	}
   147  
   148  	return conn, nil
   149  }
   150  
   151  func (s *Shunt) PacketConn(ctx context.Context, host netapi.Address) (net.PacketConn, error) {
   152  	mode, host := s.dispatch(ctx, s.config.Udp, host)
   153  
   154  	p, err := s.d.Get(ctx, "udp", mode.Mode().String(), mode.GetTag())
   155  	if err != nil {
   156  		return nil, fmt.Errorf("dial %s failed: %w", host, err)
   157  	}
   158  
   159  	conn, err := p.PacketConn(ctx, host)
   160  	if err != nil {
   161  		return nil, fmt.Errorf("dial %s failed: %w", host, err)
   162  	}
   163  
   164  	return conn, nil
   165  }
   166  
   167  func (s *Shunt) Dispatch(ctx context.Context, host netapi.Address) (netapi.Address, error) {
   168  	_, addr := s.dispatch(ctx, bypass.Mode_bypass, host)
   169  	return addr, nil
   170  }
   171  
   172  func (s *Shunt) Search(ctx context.Context, addr netapi.Address) bypass.ModeEnum {
   173  	mode, ok := s.customMapper.Search(ctx, addr)
   174  	if ok {
   175  		return mode
   176  	}
   177  
   178  	mode, ok = s.mapper.Search(ctx, addr)
   179  	if ok {
   180  		return mode
   181  	}
   182  
   183  	return bypass.Mode_proxy
   184  }
   185  
   186  func (s *Shunt) dispatch(ctx context.Context, networkMode bypass.Mode, host netapi.Address) (bypass.ModeEnum, netapi.Address) {
   187  	var mode bypass.ModeEnum = bypass.Mode_bypass
   188  
   189  	process := s.DumpProcess(ctx, host)
   190  	if process != "" {
   191  		m, ok := s.processMapper.Load(process)
   192  		if ok {
   193  			mode = m
   194  		}
   195  	}
   196  
   197  	// get mode from upstream specified
   198  	store := netapi.StoreFromContext(ctx)
   199  
   200  	if mode.Mode() == bypass.Mode_bypass {
   201  		mode = netapi.GetDefault(
   202  			ctx,
   203  			ForceModeKey{},
   204  			networkMode, // get mode from network(tcp/udp) rule
   205  		)
   206  	}
   207  
   208  	if mode.Mode() == bypass.Mode_bypass {
   209  		// get mode from bypass rule
   210  		host.SetResolver(s.r.Get(""))
   211  		mode = s.Search(ctx, host)
   212  		if mode.GetResolveStrategy() == bypass.ResolveStrategy_prefer_ipv6 {
   213  			host.PreferIPv6(true)
   214  		}
   215  	}
   216  
   217  	store.Add(modeMarkKey{}, mode.Mode())
   218  	host.SetResolver(s.r.Get(mode.Mode().String()))
   219  
   220  	if s.resolveDomain && host.IsFqdn() && mode == bypass.Mode_proxy {
   221  		// resolve proxy domain if resolveRemoteDomain enabled
   222  		ip, err := host.IP(ctx)
   223  		if err == nil {
   224  			store.Add(DOMAIN_MARK_KEY{}, host.String())
   225  			host = host.OverrideHostname(ip.String())
   226  			store.Add(IP_MARK_KEY{}, host.String())
   227  		} else {
   228  			log.Warn("resolve remote domain failed", "err", err)
   229  		}
   230  	}
   231  
   232  	return mode, host
   233  }
   234  
   235  func (s *Shunt) Resolver(ctx context.Context, domain string) netapi.Resolver {
   236  	host := netapi.ParseAddressPort(0, domain, netapi.EmptyPort)
   237  	host.SetResolver(trie.SkipResolver)
   238  	return s.r.Get(s.Search(ctx, host).Mode().String())
   239  }
   240  
   241  func (f *Shunt) LookupIP(ctx context.Context, domain string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) {
   242  	return f.Resolver(ctx, domain).LookupIP(ctx, domain, opts...)
   243  }
   244  
   245  func (f *Shunt) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) {
   246  	return f.Resolver(ctx, strings.TrimSuffix(req.Name.String(), ".")).Raw(ctx, req)
   247  }
   248  
   249  func (f *Shunt) Close() error { return nil }
   250  
   251  func (c *Shunt) DumpProcess(ctx context.Context, addr netapi.Address) (s string) {
   252  	if c.ProcessDumper == nil {
   253  		return
   254  	}
   255  
   256  	store := netapi.StoreFromContext(ctx)
   257  
   258  	source, ok := store.Get(netapi.SourceKey{})
   259  	if !ok {
   260  		return
   261  	}
   262  
   263  	var dst []any
   264  	ds, ok := store.Get(netapi.InboundKey{})
   265  	if ok {
   266  		dst = append(dst, ds)
   267  	}
   268  	ds, ok = store.Get(netapi.DestinationKey{})
   269  	if ok {
   270  		dst = append(dst, ds)
   271  	}
   272  
   273  	if len(dst) == 0 {
   274  		return
   275  	}
   276  
   277  	sourceAddr, err := convert.ToProxyAddress(addr.NetworkType(), source)
   278  	if err != nil {
   279  		return
   280  	}
   281  
   282  	for _, d := range dst {
   283  		dst, err := convert.ToProxyAddress(addr.NetworkType(), d)
   284  		if err != nil {
   285  			continue
   286  		}
   287  
   288  		process, err := c.ProcessDumper.ProcessName(addr.Network(), sourceAddr, dst)
   289  		if err != nil {
   290  			log.Warn("get process name failed", "err", err)
   291  			continue
   292  		}
   293  
   294  		store.Add("Process", process)
   295  		return process
   296  	}
   297  
   298  	return ""
   299  }