github.com/ipfans/trojan-go@v0.11.0/tunnel/router/client.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"regexp"
     7  	"runtime"
     8  	"strconv"
     9  	"strings"
    10  
    11  	v2router "github.com/v2fly/v2ray-core/v4/app/router"
    12  
    13  	"github.com/ipfans/trojan-go/common"
    14  	"github.com/ipfans/trojan-go/common/geodata"
    15  	"github.com/ipfans/trojan-go/config"
    16  	"github.com/ipfans/trojan-go/log"
    17  	"github.com/ipfans/trojan-go/tunnel"
    18  	"github.com/ipfans/trojan-go/tunnel/freedom"
    19  	"github.com/ipfans/trojan-go/tunnel/transport"
    20  )
    21  
    22  const (
    23  	Block  = 0
    24  	Bypass = 1
    25  	Proxy  = 2
    26  )
    27  
    28  const (
    29  	AsIs         = 0
    30  	IPIfNonMatch = 1
    31  	IPOnDemand   = 2
    32  )
    33  
    34  const MaxPacketSize = 1024 * 8
    35  
    36  func matchDomain(list []*v2router.Domain, target string) bool {
    37  	for _, d := range list {
    38  		switch d.GetType() {
    39  		case v2router.Domain_Full:
    40  			domain := d.GetValue()
    41  			if domain == target {
    42  				log.Tracef("domain %s hit domain(full) rule: %s", target, domain)
    43  				return true
    44  			}
    45  		case v2router.Domain_Domain:
    46  			domain := d.GetValue()
    47  			if strings.HasSuffix(target, domain) {
    48  				idx := strings.Index(target, domain)
    49  				if idx == 0 || target[idx-1] == '.' {
    50  					log.Tracef("domain %s hit domain rule: %s", target, domain)
    51  					return true
    52  				}
    53  			}
    54  		case v2router.Domain_Plain:
    55  			// keyword
    56  			if strings.Contains(target, d.GetValue()) {
    57  				log.Tracef("domain %s hit keyword rule: %s", target, d.GetValue())
    58  				return true
    59  			}
    60  		case v2router.Domain_Regex:
    61  			matched, err := regexp.Match(d.GetValue(), []byte(target))
    62  			if err != nil {
    63  				log.Error("invalid regex", d.GetValue())
    64  				return false
    65  			}
    66  			if matched {
    67  				log.Tracef("domain %s hit regex rule: %s", target, d.GetValue())
    68  				return true
    69  			}
    70  		default:
    71  			log.Debug("unknown rule type:", d.GetType().String())
    72  		}
    73  	}
    74  	return false
    75  }
    76  
    77  func matchIP(list []*v2router.CIDR, target net.IP) bool {
    78  	isIPv6 := true
    79  	len := net.IPv6len
    80  	if target.To4() != nil {
    81  		len = net.IPv4len
    82  		isIPv6 = false
    83  	}
    84  	for _, c := range list {
    85  		n := int(c.GetPrefix())
    86  		mask := net.CIDRMask(n, 8*len)
    87  		cidrIP := net.IP(c.GetIp())
    88  		if cidrIP.To4() != nil { // IPv4 CIDR
    89  			if isIPv6 {
    90  				continue
    91  			}
    92  		} else { // IPv6 CIDR
    93  			if !isIPv6 {
    94  				continue
    95  			}
    96  		}
    97  		subnet := &net.IPNet{IP: cidrIP.Mask(mask), Mask: mask}
    98  		if subnet.Contains(target) {
    99  			return true
   100  		}
   101  	}
   102  	return false
   103  }
   104  
   105  func newIPAddress(address *tunnel.Address) (*tunnel.Address, error) {
   106  	ip, err := address.ResolveIP()
   107  	if err != nil {
   108  		return nil, common.NewError("router failed to resolve ip").Base(err)
   109  	}
   110  	newAddress := &tunnel.Address{
   111  		IP:   ip,
   112  		Port: address.Port,
   113  	}
   114  	if ip.To4() != nil {
   115  		newAddress.AddressType = tunnel.IPv4
   116  	} else {
   117  		newAddress.AddressType = tunnel.IPv6
   118  	}
   119  	return newAddress, nil
   120  }
   121  
// Client routes each destination to one of three policies: Block (refuse),
// Bypass (dial directly through a freedom client) or Proxy (hand off to the
// underlying tunnel client).
type Client struct {
	domains        [3][]*v2router.Domain // domain rules, indexed by policy (Block, Bypass, Proxy)
	cidrs          [3][]*v2router.CIDR   // CIDR rules, indexed by policy (Block, Bypass, Proxy)
	defaultPolicy  int                   // policy used when no rule matches
	domainStrategy int                   // AsIs, IPIfNonMatch or IPOnDemand
	underlay       tunnel.Client         // used for the Proxy policy
	direct         *freedom.Client       // used for the Bypass policy
	ctx            context.Context       // cancelled by Close; parent of DialPacket connection contexts
	cancel         context.CancelFunc    // cancels ctx
}
   132  
   133  func (c *Client) Route(address *tunnel.Address) int {
   134  	if address.AddressType == tunnel.DomainName {
   135  		if c.domainStrategy == IPOnDemand {
   136  			resolvedIP, err := newIPAddress(address)
   137  			if err == nil {
   138  				for i := Block; i <= Proxy; i++ {
   139  					if matchIP(c.cidrs[i], resolvedIP.IP) {
   140  						return i
   141  					}
   142  				}
   143  			}
   144  		}
   145  		for i := Block; i <= Proxy; i++ {
   146  			if matchDomain(c.domains[i], address.DomainName) {
   147  				return i
   148  			}
   149  		}
   150  		if c.domainStrategy == IPIfNonMatch {
   151  			resolvedIP, err := newIPAddress(address)
   152  			if err == nil {
   153  				for i := Block; i <= Proxy; i++ {
   154  					if matchIP(c.cidrs[i], resolvedIP.IP) {
   155  						return i
   156  					}
   157  				}
   158  			}
   159  		}
   160  	} else {
   161  		for i := Block; i <= Proxy; i++ {
   162  			if matchIP(c.cidrs[i], address.IP) {
   163  				return i
   164  			}
   165  		}
   166  	}
   167  	return c.defaultPolicy
   168  }
   169  
   170  func (c *Client) DialConn(address *tunnel.Address, overlay tunnel.Tunnel) (tunnel.Conn, error) {
   171  	policy := c.Route(address)
   172  	switch policy {
   173  	case Proxy:
   174  		return c.underlay.DialConn(address, overlay)
   175  	case Block:
   176  		return nil, common.NewError("router blocked address: " + address.String())
   177  	case Bypass:
   178  		conn, err := c.direct.DialConn(address, &Tunnel{})
   179  		if err != nil {
   180  			return nil, common.NewError("router dial error").Base(err)
   181  		}
   182  		return &transport.Conn{
   183  			Conn: conn,
   184  		}, nil
   185  	}
   186  	panic("unknown policy")
   187  }
   188  
   189  func (c *Client) DialPacket(overlay tunnel.Tunnel) (tunnel.PacketConn, error) {
   190  	directConn, err := net.ListenPacket("udp", "")
   191  	if err != nil {
   192  		return nil, common.NewError("router failed to dial udp (direct)").Base(err)
   193  	}
   194  	proxy, err := c.underlay.DialPacket(overlay)
   195  	if err != nil {
   196  		return nil, common.NewError("router failed to dial udp (proxy)").Base(err)
   197  	}
   198  	ctx, cancel := context.WithCancel(c.ctx)
   199  	conn := &PacketConn{
   200  		Client:     c,
   201  		PacketConn: directConn,
   202  		proxy:      proxy,
   203  		cancel:     cancel,
   204  		ctx:        ctx,
   205  		packetChan: make(chan *packetInfo, 16),
   206  	}
   207  	go conn.packetLoop()
   208  	return conn, nil
   209  }
   210  
   211  func (c *Client) Close() error {
   212  	c.cancel()
   213  	return c.underlay.Close()
   214  }
   215  
// codeInfo is one rule taken from the router config. It holds the text after
// the rule's type prefix (e.g. "cn" from "geoip:cn") and the policy of the
// config list the rule was found in.
type codeInfo struct {
	code     string // rule body with the type prefix stripped
	strategy int    // Proxy, Bypass or Block
}
   220  
   221  func loadCode(cfg *Config, prefix string) []codeInfo {
   222  	codes := []codeInfo{}
   223  	for _, s := range cfg.Router.Proxy {
   224  		if strings.HasPrefix(s, prefix) {
   225  			if left := s[len(prefix):]; len(left) > 0 {
   226  				codes = append(codes, codeInfo{
   227  					code:     left,
   228  					strategy: Proxy,
   229  				})
   230  			} else {
   231  				log.Warn("invalid empty rule:", s)
   232  			}
   233  		}
   234  	}
   235  	for _, s := range cfg.Router.Bypass {
   236  		if strings.HasPrefix(s, prefix) {
   237  			if left := s[len(prefix):]; len(left) > 0 {
   238  				codes = append(codes, codeInfo{
   239  					code:     left,
   240  					strategy: Bypass,
   241  				})
   242  			} else {
   243  				log.Warn("invalid empty rule:", s)
   244  			}
   245  		}
   246  	}
   247  	for _, s := range cfg.Router.Block {
   248  		if strings.HasPrefix(s, prefix) {
   249  			if left := s[len(prefix):]; len(left) > 0 {
   250  				codes = append(codes, codeInfo{
   251  					code:     left,
   252  					strategy: Block,
   253  				})
   254  			} else {
   255  				log.Warn("invalid empty rule:", s)
   256  			}
   257  		}
   258  	}
   259  	return codes
   260  }
   261  
   262  func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
   263  	m1 := runtime.MemStats{}
   264  	m2 := runtime.MemStats{}
   265  	m3 := runtime.MemStats{}
   266  	m4 := runtime.MemStats{}
   267  
   268  	cfg := config.FromContext(ctx, Name).(*Config)
   269  	var cancel context.CancelFunc
   270  	ctx, cancel = context.WithCancel(ctx)
   271  
   272  	direct, err := freedom.NewClient(ctx, nil)
   273  	if err != nil {
   274  		cancel()
   275  		return nil, common.NewError("router failed to initialize raw client").Base(err)
   276  	}
   277  
   278  	client := &Client{
   279  		domains:  [3][]*v2router.Domain{},
   280  		cidrs:    [3][]*v2router.CIDR{},
   281  		underlay: underlay,
   282  		direct:   direct,
   283  		ctx:      ctx,
   284  		cancel:   cancel,
   285  	}
   286  	switch strings.ToLower(cfg.Router.DomainStrategy) {
   287  	case "as_is", "as-is", "asis":
   288  		client.domainStrategy = AsIs
   289  	case "ip_if_non_match", "ip-if-non-match", "ipifnonmatch":
   290  		client.domainStrategy = IPIfNonMatch
   291  	case "ip_on_demand", "ip-on-demand", "ipondemand":
   292  		client.domainStrategy = IPOnDemand
   293  	default:
   294  		return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy)
   295  	}
   296  
   297  	switch strings.ToLower(cfg.Router.DefaultPolicy) {
   298  	case "proxy":
   299  		client.defaultPolicy = Proxy
   300  	case "bypass":
   301  		client.defaultPolicy = Bypass
   302  	case "block":
   303  		client.defaultPolicy = Block
   304  	default:
   305  		return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy)
   306  	}
   307  
   308  	runtime.ReadMemStats(&m1)
   309  
   310  	geodataLoader := geodata.NewGeodataLoader()
   311  
   312  	ipCode := loadCode(cfg, "geoip:")
   313  	for _, c := range ipCode {
   314  		code := c.code
   315  		cidrs, err := geodataLoader.LoadIP(cfg.Router.GeoIPFilename, code)
   316  		if err != nil {
   317  			log.Error(err)
   318  		} else {
   319  			log.Infof("geoip:%s loaded", code)
   320  			client.cidrs[c.strategy] = append(client.cidrs[c.strategy], cidrs...)
   321  		}
   322  	}
   323  
   324  	runtime.ReadMemStats(&m2)
   325  
   326  	siteCode := loadCode(cfg, "geosite:")
   327  	for _, c := range siteCode {
   328  		code := c.code
   329  		attrWanted := ""
   330  		// Test if user wants domains that have an attribute
   331  		if attrIdx := strings.Index(code, "@"); attrIdx > 0 {
   332  			if !strings.HasSuffix(code, "@") {
   333  				code = c.code[:attrIdx]
   334  				attrWanted = c.code[attrIdx+1:]
   335  			} else { // "geosite:google@" is invalid
   336  				log.Warnf("geosite:%s invalid", code)
   337  				continue
   338  			}
   339  		} else if attrIdx == 0 { // "geosite:@cn" is invalid
   340  			log.Warnf("geosite:%s invalid", code)
   341  			continue
   342  		}
   343  
   344  		domainList, err := geodataLoader.LoadSite(cfg.Router.GeoSiteFilename, code)
   345  		if err != nil {
   346  			log.Error(err)
   347  		} else {
   348  			found := false
   349  			if attrWanted != "" {
   350  				for _, domain := range domainList {
   351  					for _, attr := range domain.GetAttribute() {
   352  						if strings.EqualFold(attrWanted, attr.GetKey()) {
   353  							client.domains[c.strategy] = append(client.domains[c.strategy], domain)
   354  							found = true
   355  						}
   356  					}
   357  				}
   358  			} else {
   359  				client.domains[c.strategy] = append(client.domains[c.strategy], domainList...)
   360  				found = true
   361  			}
   362  			if found {
   363  				log.Infof("geosite:%s loaded", c.code)
   364  			} else {
   365  				log.Errorf("geosite:%s not found", c.code)
   366  			}
   367  		}
   368  	}
   369  
   370  	runtime.ReadMemStats(&m3)
   371  
   372  	domainInfo := loadCode(cfg, "domain:")
   373  	for _, info := range domainInfo {
   374  		client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
   375  			Type:      v2router.Domain_Domain,
   376  			Value:     strings.ToLower(info.code),
   377  			Attribute: nil,
   378  		})
   379  	}
   380  
   381  	keywordInfo := loadCode(cfg, "keyword:")
   382  	for _, info := range keywordInfo {
   383  		client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
   384  			Type:      v2router.Domain_Plain,
   385  			Value:     strings.ToLower(info.code),
   386  			Attribute: nil,
   387  		})
   388  	}
   389  
   390  	regexInfo := loadCode(cfg, "regex:")
   391  	for _, info := range regexInfo {
   392  		if _, err := regexp.Compile(info.code); err != nil {
   393  			return nil, common.NewError("invalid regular expression: " + info.code).Base(err)
   394  		}
   395  		client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
   396  			Type:      v2router.Domain_Regex,
   397  			Value:     info.code,
   398  			Attribute: nil,
   399  		})
   400  	}
   401  
   402  	// Just for compatibility with V2Ray rule type `regexp`
   403  	regexpInfo := loadCode(cfg, "regexp:")
   404  	for _, info := range regexpInfo {
   405  		if _, err := regexp.Compile(info.code); err != nil {
   406  			return nil, common.NewError("invalid regular expression: " + info.code).Base(err)
   407  		}
   408  		client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
   409  			Type:      v2router.Domain_Regex,
   410  			Value:     info.code,
   411  			Attribute: nil,
   412  		})
   413  	}
   414  
   415  	fullInfo := loadCode(cfg, "full:")
   416  	for _, info := range fullInfo {
   417  		client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{
   418  			Type:      v2router.Domain_Full,
   419  			Value:     strings.ToLower(info.code),
   420  			Attribute: nil,
   421  		})
   422  	}
   423  
   424  	cidrInfo := loadCode(cfg, "cidr:")
   425  	for _, info := range cidrInfo {
   426  		tmp := strings.Split(info.code, "/")
   427  		if len(tmp) != 2 {
   428  			return nil, common.NewError("invalid cidr: " + info.code)
   429  		}
   430  		ip := net.ParseIP(tmp[0])
   431  		if ip == nil {
   432  			return nil, common.NewError("invalid cidr ip: " + info.code)
   433  		}
   434  		prefix, err := strconv.ParseInt(tmp[1], 10, 32)
   435  		if err != nil {
   436  			return nil, common.NewError("invalid prefix").Base(err)
   437  		}
   438  		client.cidrs[info.strategy] = append(client.cidrs[info.strategy], &v2router.CIDR{
   439  			Ip:     ip,
   440  			Prefix: uint32(prefix),
   441  		})
   442  	}
   443  
   444  	log.Info("router client created")
   445  
   446  	runtime.ReadMemStats(&m4)
   447  
   448  	log.Debugf("GeoIP rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m2.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m2.TotalAlloc-m1.TotalAlloc))
   449  	log.Debugf("GeoSite rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m3.Alloc-m2.Alloc), common.HumanFriendlyTraffic(m3.TotalAlloc-m2.TotalAlloc))
   450  	log.Debugf("Plaintext rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m3.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m3.TotalAlloc))
   451  	log.Debugf("Total(router) -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m1.TotalAlloc))
   452  
   453  	return client, nil
   454  }