github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/components/shunt/shunt.go (about) 1 package shunt 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "os" 8 "slices" 9 "strings" 10 "sync" 11 12 "github.com/Asutorufa/yuhaiin/pkg/log" 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 "github.com/Asutorufa/yuhaiin/pkg/net/trie" 15 pc "github.com/Asutorufa/yuhaiin/pkg/protos/config" 16 "github.com/Asutorufa/yuhaiin/pkg/protos/config/bypass" 17 "github.com/Asutorufa/yuhaiin/pkg/utils/convert" 18 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 19 "golang.org/x/exp/maps" 20 "golang.org/x/net/dns/dnsmessage" 21 "google.golang.org/protobuf/proto" 22 ) 23 24 type modeMarkKey struct{} 25 26 func (modeMarkKey) String() string { return "MODE" } 27 28 type DOMAIN_MARK_KEY struct{} 29 30 type IP_MARK_KEY struct{} 31 32 func (IP_MARK_KEY) String() string { return "IP" } 33 34 type ForceModeKey struct{} 35 36 type Shunt struct { 37 resolveDomain bool 38 modifiedTime int64 39 40 config *bypass.BypassConfig 41 mapper *trie.Trie[bypass.ModeEnum] 42 customMapper *trie.Trie[bypass.ModeEnum] 43 44 processMapper syncmap.SyncMap[string, bypass.ModeEnum] 45 ProcessDumper netapi.ProcessDumper 46 47 mu sync.RWMutex 48 49 r Resolver 50 d Dialer 51 52 tags map[string]struct{} 53 } 54 55 type Resolver interface { 56 Get(str string) netapi.Resolver 57 } 58 type Dialer interface { 59 Get(ctx context.Context, network string, str string, tag string) (netapi.Proxy, error) 60 } 61 62 func NewShunt(d Dialer, r Resolver, ProcessDumper netapi.ProcessDumper) *Shunt { 63 return &Shunt{ 64 mapper: trie.NewTrie[bypass.ModeEnum](), 65 customMapper: trie.NewTrie[bypass.ModeEnum](), 66 config: &bypass.BypassConfig{ 67 Tcp: bypass.Mode_bypass, 68 Udp: bypass.Mode_bypass, 69 }, 70 r: r, 71 d: d, 72 ProcessDumper: ProcessDumper, 73 tags: make(map[string]struct{}), 74 } 75 } 76 77 func (s *Shunt) Update(c *pc.Setting) { 78 s.mu.Lock() 79 defer s.mu.Unlock() 80 81 s.resolveDomain = c.Dns.ResolveRemoteDomain 82 83 if !slices.EqualFunc( 84 s.config.CustomRuleV3, 85 c.Bypass.CustomRuleV3, 86 func(mc1, mc2 *bypass.ModeConfig) bool { return proto.Equal(mc1, mc2) }, 87 ) { 88 s.customMapper.Clear() //nolint:errcheck 89 s.processMapper = syncmap.SyncMap[string, bypass.ModeEnum]{} 90 91 for _, v := range c.Bypass.CustomRuleV3 { 92 mark := v.ToModeEnum() 93 94 if mark.GetTag() != "" { 95 s.tags[mark.GetTag()] = struct{}{} 96 } 97 98 for _, hostname := range v.Hostname { 99 if strings.HasPrefix(hostname, "process:") { 100 s.processMapper.Store(hostname[8:], mark) 101 } else { 102 s.customMapper.Insert(hostname, mark) 103 } 104 } 105 } 106 } 107 108 modifiedTime := s.modifiedTime 109 if stat, err := os.Stat(c.Bypass.BypassFile); err == nil { 110 modifiedTime = stat.ModTime().Unix() 111 } 112 113 if s.config.BypassFile != c.Bypass.BypassFile || s.modifiedTime != modifiedTime { 114 s.mapper.Clear() //nolint:errcheck 115 s.tags = make(map[string]struct{}) 116 s.modifiedTime = modifiedTime 117 rangeRule(c.Bypass.BypassFile, func(s1 string, s2 bypass.ModeEnum) { 118 if strings.HasPrefix(s1, "process:") { 119 s.processMapper.Store(s1[8:], s2.Mode()) 120 } else { 121 s.mapper.Insert(s1, s2) 122 } 123 124 if s2.GetTag() != "" { 125 s.tags[s2.GetTag()] = struct{}{} 126 } 127 }) 128 } 129 130 s.config = c.Bypass 131 } 132 133 func (s *Shunt) Tags() []string { return maps.Keys(s.tags) } 134 135 func (s *Shunt) Conn(ctx context.Context, host netapi.Address) (net.Conn, error) { 136 mode, host := s.dispatch(ctx, s.config.Tcp, host) 137 138 p, err := s.d.Get(ctx, "tcp", mode.Mode().String(), mode.GetTag()) 139 if err != nil { 140 return nil, fmt.Errorf("dial %s failed: %w", host, err) 141 } 142 143 conn, err := p.Conn(ctx, host) 144 if err != nil { 145 return nil, fmt.Errorf("dial %s failed: %w", host, err) 146 } 147 148 return conn, nil 149 } 150 151 func (s *Shunt) PacketConn(ctx context.Context, host netapi.Address) (net.PacketConn, error) { 152 mode, host := s.dispatch(ctx, s.config.Udp, host) 153 154 p, err := s.d.Get(ctx, "udp", mode.Mode().String(), mode.GetTag()) 155 if err != nil { 156 return nil, fmt.Errorf("dial %s failed: %w", host, err) 157 } 158 159 conn, err := p.PacketConn(ctx, host) 160 if err != nil { 161 return nil, fmt.Errorf("dial %s failed: %w", host, err) 162 } 163 164 return conn, nil 165 } 166 167 func (s *Shunt) Dispatch(ctx context.Context, host netapi.Address) (netapi.Address, error) { 168 _, addr := s.dispatch(ctx, bypass.Mode_bypass, host) 169 return addr, nil 170 } 171 172 func (s *Shunt) Search(ctx context.Context, addr netapi.Address) bypass.ModeEnum { 173 mode, ok := s.customMapper.Search(ctx, addr) 174 if ok { 175 return mode 176 } 177 178 mode, ok = s.mapper.Search(ctx, addr) 179 if ok { 180 return mode 181 } 182 183 return bypass.Mode_proxy 184 } 185 186 func (s *Shunt) dispatch(ctx context.Context, networkMode bypass.Mode, host netapi.Address) (bypass.ModeEnum, netapi.Address) { 187 var mode bypass.ModeEnum = bypass.Mode_bypass 188 189 process := s.DumpProcess(ctx, host) 190 if process != "" { 191 m, ok := s.processMapper.Load(process) 192 if ok { 193 mode = m 194 } 195 } 196 197 // get mode from upstream specified 198 store := netapi.StoreFromContext(ctx) 199 200 if mode.Mode() == bypass.Mode_bypass { 201 mode = netapi.GetDefault( 202 ctx, 203 ForceModeKey{}, 204 networkMode, // get mode from network(tcp/udp) rule 205 ) 206 } 207 208 if mode.Mode() == bypass.Mode_bypass { 209 // get mode from bypass rule 210 host.SetResolver(s.r.Get("")) 211 mode = s.Search(ctx, host) 212 if mode.GetResolveStrategy() == bypass.ResolveStrategy_prefer_ipv6 { 213 host.PreferIPv6(true) 214 } 215 } 216 217 store.Add(modeMarkKey{}, mode.Mode()) 218 host.SetResolver(s.r.Get(mode.Mode().String())) 219 220 if s.resolveDomain && host.IsFqdn() && mode == bypass.Mode_proxy { 221 // resolve proxy domain if resolveRemoteDomain enabled 222 ip, err := host.IP(ctx) 223 if err == nil { 224 store.Add(DOMAIN_MARK_KEY{}, host.String()) 225 host = host.OverrideHostname(ip.String()) 226 store.Add(IP_MARK_KEY{}, host.String()) 227 } else { 228 log.Warn("resolve remote domain failed", "err", err) 229 } 230 } 231 232 return mode, host 233 } 234 235 func (s *Shunt) Resolver(ctx context.Context, domain string) netapi.Resolver { 236 host := netapi.ParseAddressPort(0, domain, netapi.EmptyPort) 237 host.SetResolver(trie.SkipResolver) 238 return s.r.Get(s.Search(ctx, host).Mode().String()) 239 } 240 241 func (f *Shunt) LookupIP(ctx context.Context, domain string, opts ...func(*netapi.LookupIPOption)) ([]net.IP, error) { 242 return f.Resolver(ctx, domain).LookupIP(ctx, domain, opts...) 243 } 244 245 func (f *Shunt) Raw(ctx context.Context, req dnsmessage.Question) (dnsmessage.Message, error) { 246 return f.Resolver(ctx, strings.TrimSuffix(req.Name.String(), ".")).Raw(ctx, req) 247 } 248 249 func (f *Shunt) Close() error { return nil } 250 251 func (c *Shunt) DumpProcess(ctx context.Context, addr netapi.Address) (s string) { 252 if c.ProcessDumper == nil { 253 return 254 } 255 256 store := netapi.StoreFromContext(ctx) 257 258 source, ok := store.Get(netapi.SourceKey{}) 259 if !ok { 260 return 261 } 262 263 var dst []any 264 ds, ok := store.Get(netapi.InboundKey{}) 265 if ok { 266 dst = append(dst, ds) 267 } 268 ds, ok = store.Get(netapi.DestinationKey{}) 269 if ok { 270 dst = append(dst, ds) 271 } 272 273 if len(dst) == 0 { 274 return 275 } 276 277 sourceAddr, err := convert.ToProxyAddress(addr.NetworkType(), source) 278 if err != nil { 279 return 280 } 281 282 for _, d := range dst { 283 dst, err := convert.ToProxyAddress(addr.NetworkType(), d) 284 if err != nil { 285 continue 286 } 287 288 process, err := c.ProcessDumper.ProcessName(addr.Network(), sourceAddr, dst) 289 if err != nil { 290 log.Warn("get process name failed", "err", err) 291 continue 292 } 293 294 store.Add("Process", process) 295 return process 296 } 297 298 return "" 299 }