github.com/ipfans/trojan-go@v0.11.0/tunnel/router/client.go (about) 1 package router 2 3 import ( 4 "context" 5 "net" 6 "regexp" 7 "runtime" 8 "strconv" 9 "strings" 10 11 v2router "github.com/v2fly/v2ray-core/v4/app/router" 12 13 "github.com/ipfans/trojan-go/common" 14 "github.com/ipfans/trojan-go/common/geodata" 15 "github.com/ipfans/trojan-go/config" 16 "github.com/ipfans/trojan-go/log" 17 "github.com/ipfans/trojan-go/tunnel" 18 "github.com/ipfans/trojan-go/tunnel/freedom" 19 "github.com/ipfans/trojan-go/tunnel/transport" 20 ) 21 22 const ( 23 Block = 0 24 Bypass = 1 25 Proxy = 2 26 ) 27 28 const ( 29 AsIs = 0 30 IPIfNonMatch = 1 31 IPOnDemand = 2 32 ) 33 34 const MaxPacketSize = 1024 * 8 35 36 func matchDomain(list []*v2router.Domain, target string) bool { 37 for _, d := range list { 38 switch d.GetType() { 39 case v2router.Domain_Full: 40 domain := d.GetValue() 41 if domain == target { 42 log.Tracef("domain %s hit domain(full) rule: %s", target, domain) 43 return true 44 } 45 case v2router.Domain_Domain: 46 domain := d.GetValue() 47 if strings.HasSuffix(target, domain) { 48 idx := strings.Index(target, domain) 49 if idx == 0 || target[idx-1] == '.' { 50 log.Tracef("domain %s hit domain rule: %s", target, domain) 51 return true 52 } 53 } 54 case v2router.Domain_Plain: 55 // keyword 56 if strings.Contains(target, d.GetValue()) { 57 log.Tracef("domain %s hit keyword rule: %s", target, d.GetValue()) 58 return true 59 } 60 case v2router.Domain_Regex: 61 matched, err := regexp.Match(d.GetValue(), []byte(target)) 62 if err != nil { 63 log.Error("invalid regex", d.GetValue()) 64 return false 65 } 66 if matched { 67 log.Tracef("domain %s hit regex rule: %s", target, d.GetValue()) 68 return true 69 } 70 default: 71 log.Debug("unknown rule type:", d.GetType().String()) 72 } 73 } 74 return false 75 } 76 77 func matchIP(list []*v2router.CIDR, target net.IP) bool { 78 isIPv6 := true 79 len := net.IPv6len 80 if target.To4() != nil { 81 len = net.IPv4len 82 isIPv6 = false 83 } 84 for _, c := range list { 85 n := int(c.GetPrefix()) 86 mask := net.CIDRMask(n, 8*len) 87 cidrIP := net.IP(c.GetIp()) 88 if cidrIP.To4() != nil { // IPv4 CIDR 89 if isIPv6 { 90 continue 91 } 92 } else { // IPv6 CIDR 93 if !isIPv6 { 94 continue 95 } 96 } 97 subnet := &net.IPNet{IP: cidrIP.Mask(mask), Mask: mask} 98 if subnet.Contains(target) { 99 return true 100 } 101 } 102 return false 103 } 104 105 func newIPAddress(address *tunnel.Address) (*tunnel.Address, error) { 106 ip, err := address.ResolveIP() 107 if err != nil { 108 return nil, common.NewError("router failed to resolve ip").Base(err) 109 } 110 newAddress := &tunnel.Address{ 111 IP: ip, 112 Port: address.Port, 113 } 114 if ip.To4() != nil { 115 newAddress.AddressType = tunnel.IPv4 116 } else { 117 newAddress.AddressType = tunnel.IPv6 118 } 119 return newAddress, nil 120 } 121 122 type Client struct { 123 domains [3][]*v2router.Domain 124 cidrs [3][]*v2router.CIDR 125 defaultPolicy int 126 domainStrategy int 127 underlay tunnel.Client 128 direct *freedom.Client 129 ctx context.Context 130 cancel context.CancelFunc 131 } 132 133 func (c *Client) Route(address *tunnel.Address) int { 134 if address.AddressType == tunnel.DomainName { 135 if c.domainStrategy == IPOnDemand { 136 resolvedIP, err := newIPAddress(address) 137 if err == nil { 138 for i := Block; i <= Proxy; i++ { 139 if matchIP(c.cidrs[i], resolvedIP.IP) { 140 return i 141 } 142 } 143 } 144 } 145 for i := Block; i <= Proxy; i++ { 146 if matchDomain(c.domains[i], address.DomainName) { 147 return i 148 } 149 } 150 if c.domainStrategy == IPIfNonMatch { 151 resolvedIP, err := newIPAddress(address) 152 if err == nil { 153 for i := Block; i <= Proxy; i++ { 154 if matchIP(c.cidrs[i], resolvedIP.IP) { 155 return i 156 } 157 } 158 } 159 } 160 } else { 161 for i := Block; i <= Proxy; i++ { 162 if matchIP(c.cidrs[i], address.IP) { 163 return i 164 } 165 } 166 } 167 return c.defaultPolicy 168 } 169 170 func (c *Client) DialConn(address *tunnel.Address, overlay tunnel.Tunnel) (tunnel.Conn, error) { 171 policy := c.Route(address) 172 switch policy { 173 case Proxy: 174 return c.underlay.DialConn(address, overlay) 175 case Block: 176 return nil, common.NewError("router blocked address: " + address.String()) 177 case Bypass: 178 conn, err := c.direct.DialConn(address, &Tunnel{}) 179 if err != nil { 180 return nil, common.NewError("router dial error").Base(err) 181 } 182 return &transport.Conn{ 183 Conn: conn, 184 }, nil 185 } 186 panic("unknown policy") 187 } 188 189 func (c *Client) DialPacket(overlay tunnel.Tunnel) (tunnel.PacketConn, error) { 190 directConn, err := net.ListenPacket("udp", "") 191 if err != nil { 192 return nil, common.NewError("router failed to dial udp (direct)").Base(err) 193 } 194 proxy, err := c.underlay.DialPacket(overlay) 195 if err != nil { 196 return nil, common.NewError("router failed to dial udp (proxy)").Base(err) 197 } 198 ctx, cancel := context.WithCancel(c.ctx) 199 conn := &PacketConn{ 200 Client: c, 201 PacketConn: directConn, 202 proxy: proxy, 203 cancel: cancel, 204 ctx: ctx, 205 packetChan: make(chan *packetInfo, 16), 206 } 207 go conn.packetLoop() 208 return conn, nil 209 } 210 211 func (c *Client) Close() error { 212 c.cancel() 213 return c.underlay.Close() 214 } 215 216 type codeInfo struct { 217 code string 218 strategy int 219 } 220 221 func loadCode(cfg *Config, prefix string) []codeInfo { 222 codes := []codeInfo{} 223 for _, s := range cfg.Router.Proxy { 224 if strings.HasPrefix(s, prefix) { 225 if left := s[len(prefix):]; len(left) > 0 { 226 codes = append(codes, codeInfo{ 227 code: left, 228 strategy: Proxy, 229 }) 230 } else { 231 log.Warn("invalid empty rule:", s) 232 } 233 } 234 } 235 for _, s := range cfg.Router.Bypass { 236 if strings.HasPrefix(s, prefix) { 237 if left := s[len(prefix):]; len(left) > 0 { 238 codes = append(codes, codeInfo{ 239 code: left, 240 strategy: Bypass, 241 }) 242 } else { 243 log.Warn("invalid empty rule:", s) 244 } 245 } 246 } 247 for _, s := range cfg.Router.Block { 248 if strings.HasPrefix(s, prefix) { 249 if left := s[len(prefix):]; len(left) > 0 { 250 codes = append(codes, codeInfo{ 251 code: left, 252 strategy: Block, 253 }) 254 } else { 255 log.Warn("invalid empty rule:", s) 256 } 257 } 258 } 259 return codes 260 } 261 262 func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { 263 m1 := runtime.MemStats{} 264 m2 := runtime.MemStats{} 265 m3 := runtime.MemStats{} 266 m4 := runtime.MemStats{} 267 268 cfg := config.FromContext(ctx, Name).(*Config) 269 var cancel context.CancelFunc 270 ctx, cancel = context.WithCancel(ctx) 271 272 direct, err := freedom.NewClient(ctx, nil) 273 if err != nil { 274 cancel() 275 return nil, common.NewError("router failed to initialize raw client").Base(err) 276 } 277 278 client := &Client{ 279 domains: [3][]*v2router.Domain{}, 280 cidrs: [3][]*v2router.CIDR{}, 281 underlay: underlay, 282 direct: direct, 283 ctx: ctx, 284 cancel: cancel, 285 } 286 switch strings.ToLower(cfg.Router.DomainStrategy) { 287 case "as_is", "as-is", "asis": 288 client.domainStrategy = AsIs 289 case "ip_if_non_match", "ip-if-non-match", "ipifnonmatch": 290 client.domainStrategy = IPIfNonMatch 291 case "ip_on_demand", "ip-on-demand", "ipondemand": 292 client.domainStrategy = IPOnDemand 293 default: 294 return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy) 295 } 296 297 switch strings.ToLower(cfg.Router.DefaultPolicy) { 298 case "proxy": 299 client.defaultPolicy = Proxy 300 case "bypass": 301 client.defaultPolicy = Bypass 302 case "block": 303 client.defaultPolicy = Block 304 default: 305 return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy) 306 } 307 308 runtime.ReadMemStats(&m1) 309 310 geodataLoader := geodata.NewGeodataLoader() 311 312 ipCode := loadCode(cfg, "geoip:") 313 for _, c := range ipCode { 314 code := c.code 315 cidrs, err := geodataLoader.LoadIP(cfg.Router.GeoIPFilename, code) 316 if err != nil { 317 log.Error(err) 318 } else { 319 log.Infof("geoip:%s loaded", code) 320 client.cidrs[c.strategy] = append(client.cidrs[c.strategy], cidrs...) 321 } 322 } 323 324 runtime.ReadMemStats(&m2) 325 326 siteCode := loadCode(cfg, "geosite:") 327 for _, c := range siteCode { 328 code := c.code 329 attrWanted := "" 330 // Test if user wants domains that have an attribute 331 if attrIdx := strings.Index(code, "@"); attrIdx > 0 { 332 if !strings.HasSuffix(code, "@") { 333 code = c.code[:attrIdx] 334 attrWanted = c.code[attrIdx+1:] 335 } else { // "geosite:google@" is invalid 336 log.Warnf("geosite:%s invalid", code) 337 continue 338 } 339 } else if attrIdx == 0 { // "geosite:@cn" is invalid 340 log.Warnf("geosite:%s invalid", code) 341 continue 342 } 343 344 domainList, err := geodataLoader.LoadSite(cfg.Router.GeoSiteFilename, code) 345 if err != nil { 346 log.Error(err) 347 } else { 348 found := false 349 if attrWanted != "" { 350 for _, domain := range domainList { 351 for _, attr := range domain.GetAttribute() { 352 if strings.EqualFold(attrWanted, attr.GetKey()) { 353 client.domains[c.strategy] = append(client.domains[c.strategy], domain) 354 found = true 355 } 356 } 357 } 358 } else { 359 client.domains[c.strategy] = append(client.domains[c.strategy], domainList...) 360 found = true 361 } 362 if found { 363 log.Infof("geosite:%s loaded", c.code) 364 } else { 365 log.Errorf("geosite:%s not found", c.code) 366 } 367 } 368 } 369 370 runtime.ReadMemStats(&m3) 371 372 domainInfo := loadCode(cfg, "domain:") 373 for _, info := range domainInfo { 374 client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ 375 Type: v2router.Domain_Domain, 376 Value: strings.ToLower(info.code), 377 Attribute: nil, 378 }) 379 } 380 381 keywordInfo := loadCode(cfg, "keyword:") 382 for _, info := range keywordInfo { 383 client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ 384 Type: v2router.Domain_Plain, 385 Value: strings.ToLower(info.code), 386 Attribute: nil, 387 }) 388 } 389 390 regexInfo := loadCode(cfg, "regex:") 391 for _, info := range regexInfo { 392 if _, err := regexp.Compile(info.code); err != nil { 393 return nil, common.NewError("invalid regular expression: " + info.code).Base(err) 394 } 395 client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ 396 Type: v2router.Domain_Regex, 397 Value: info.code, 398 Attribute: nil, 399 }) 400 } 401 402 // Just for compatibility with V2Ray rule type `regexp` 403 regexpInfo := loadCode(cfg, "regexp:") 404 for _, info := range regexpInfo { 405 if _, err := regexp.Compile(info.code); err != nil { 406 return nil, common.NewError("invalid regular expression: " + info.code).Base(err) 407 } 408 client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ 409 Type: v2router.Domain_Regex, 410 Value: info.code, 411 Attribute: nil, 412 }) 413 } 414 415 fullInfo := loadCode(cfg, "full:") 416 for _, info := range fullInfo { 417 client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ 418 Type: v2router.Domain_Full, 419 Value: strings.ToLower(info.code), 420 Attribute: nil, 421 }) 422 } 423 424 cidrInfo := loadCode(cfg, "cidr:") 425 for _, info := range cidrInfo { 426 tmp := strings.Split(info.code, "/") 427 if len(tmp) != 2 { 428 return nil, common.NewError("invalid cidr: " + info.code) 429 } 430 ip := net.ParseIP(tmp[0]) 431 if ip == nil { 432 return nil, common.NewError("invalid cidr ip: " + info.code) 433 } 434 prefix, err := strconv.ParseInt(tmp[1], 10, 32) 435 if err != nil { 436 return nil, common.NewError("invalid prefix").Base(err) 437 } 438 client.cidrs[info.strategy] = append(client.cidrs[info.strategy], &v2router.CIDR{ 439 Ip: ip, 440 Prefix: uint32(prefix), 441 }) 442 } 443 444 log.Info("router client created") 445 446 runtime.ReadMemStats(&m4) 447 448 log.Debugf("GeoIP rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m2.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m2.TotalAlloc-m1.TotalAlloc)) 449 log.Debugf("GeoSite rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m3.Alloc-m2.Alloc), common.HumanFriendlyTraffic(m3.TotalAlloc-m2.TotalAlloc)) 450 log.Debugf("Plaintext rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m3.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m3.TotalAlloc)) 451 log.Debugf("Total(router) -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m1.TotalAlloc)) 452 453 return client, nil 454 }