github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/infra/conf/router.go (about)

     1  package conf
     2  
     3  import (
     4  	"encoding/json"
     5  	"runtime"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"github.com/xmplusdev/xmcore/app/router"
    10  	"github.com/xmplusdev/xmcore/common/net"
    11  	"github.com/xmplusdev/xmcore/common/platform/filesystem"
    12  	"github.com/xmplusdev/xmcore/common/serial"
    13  	"google.golang.org/protobuf/proto"
    14  )
    15  
// RouterRulesConfig is the JSON shape of the nested "settings" object of the
// router configuration (RouterConfig.Settings, which is marked deprecated).
type RouterRulesConfig struct {
	// RuleList holds the still-undecoded routing rules found under "rules".
	RuleList       []json.RawMessage `json:"rules"`
	// DomainStrategy is the raw "domainStrategy" string; it is only consulted
	// when RouterConfig.DomainStrategy is not set.
	DomainStrategy string            `json:"domainStrategy"`
}
    20  
// StrategyConfig represents a strategy config: the strategy name plus its
// optional, still-undecoded settings object.
type StrategyConfig struct {
	// Type is the strategy name; BalancingRule.Build lower-cases it and
	// substitutes strategyRandom when it is empty.
	Type     string           `json:"type"`
	// Settings is the raw strategy-specific JSON; nil when "settings" is absent.
	Settings *json.RawMessage `json:"settings"`
}
    26  
// BalancingRule is the JSON form of one entry of the router's "balancers"
// list. Build converts it into a router.BalancingRule.
type BalancingRule struct {
	// Tag names the balancer; Build rejects an empty tag.
	Tag         string         `json:"tag"`
	// Selectors is passed through as the outbound selector list; Build
	// rejects an empty list.
	Selectors   StringList     `json:"selector"`
	// Strategy selects the balancing strategy and carries its settings.
	Strategy    StrategyConfig `json:"strategy"`
	// FallbackTag is copied verbatim into the built rule.
	FallbackTag string         `json:"fallbackTag"`
}
    33  
    34  // Build builds the balancing rule
    35  func (r *BalancingRule) Build() (*router.BalancingRule, error) {
    36  	if r.Tag == "" {
    37  		return nil, newError("empty balancer tag")
    38  	}
    39  	if len(r.Selectors) == 0 {
    40  		return nil, newError("empty selector list")
    41  	}
    42  
    43  	r.Strategy.Type = strings.ToLower(r.Strategy.Type)
    44  	switch r.Strategy.Type {
    45  	case "":
    46  		r.Strategy.Type = strategyRandom
    47  	case strategyRandom, strategyLeastLoad, strategyLeastPing, strategyRoundRobin:
    48  	default:
    49  		return nil, newError("unknown balancing strategy: " + r.Strategy.Type)
    50  	}
    51  
    52  	settings := []byte("{}")
    53  	if r.Strategy.Settings != nil {
    54  		settings = ([]byte)(*r.Strategy.Settings)
    55  	}
    56  	rawConfig, err := strategyConfigLoader.LoadWithID(settings, r.Strategy.Type)
    57  	if err != nil {
    58  		return nil, newError("failed to parse to strategy config.").Base(err)
    59  	}
    60  	var ts proto.Message
    61  	if builder, ok := rawConfig.(Buildable); ok {
    62  		ts, err = builder.Build()
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  	}
    67  
    68  	return &router.BalancingRule{
    69  		Strategy:         r.Strategy.Type,
    70  		StrategySettings: serial.ToTypedMessage(ts),
    71  		FallbackTag:      r.FallbackTag,
    72  		OutboundSelector: r.Selectors,
    73  		Tag:              r.Tag,
    74  	}, nil
    75  }
    76  
// RouterConfig is the JSON form of the routing configuration. Build converts
// it into a router.Config.
type RouterConfig struct {
	// Settings is the older nested form; its rules are processed after
	// RuleList and its domain strategy is only used when DomainStrategy is nil.
	Settings       *RouterRulesConfig `json:"settings"` // Deprecated
	// RuleList holds the still-undecoded routing rules; each one is decoded
	// with ParseRule during Build.
	RuleList       []json.RawMessage  `json:"rules"`
	// DomainStrategy, when non-nil, takes precedence over
	// Settings.DomainStrategy.
	DomainStrategy *string            `json:"domainStrategy"`
	// Balancers lists the balancing rules built alongside the routing rules.
	Balancers      []*BalancingRule   `json:"balancers"`

	// DomainMatcher is applied by Build to every rule that does not set its
	// own domain matcher.
	DomainMatcher string `json:"domainMatcher"`
}
    85  
    86  func (c *RouterConfig) getDomainStrategy() router.Config_DomainStrategy {
    87  	ds := ""
    88  	if c.DomainStrategy != nil {
    89  		ds = *c.DomainStrategy
    90  	} else if c.Settings != nil {
    91  		ds = c.Settings.DomainStrategy
    92  	}
    93  
    94  	switch strings.ToLower(ds) {
    95  	case "alwaysip":
    96  		return router.Config_UseIp
    97  	case "ipifnonmatch":
    98  		return router.Config_IpIfNonMatch
    99  	case "ipondemand":
   100  		return router.Config_IpOnDemand
   101  	default:
   102  		return router.Config_AsIs
   103  	}
   104  }
   105  
   106  func (c *RouterConfig) Build() (*router.Config, error) {
   107  	config := new(router.Config)
   108  	config.DomainStrategy = c.getDomainStrategy()
   109  
   110  	var rawRuleList []json.RawMessage
   111  	if c != nil {
   112  		rawRuleList = c.RuleList
   113  		if c.Settings != nil {
   114  			c.RuleList = append(c.RuleList, c.Settings.RuleList...)
   115  			rawRuleList = c.RuleList
   116  		}
   117  	}
   118  
   119  	for _, rawRule := range rawRuleList {
   120  		rule, err := ParseRule(rawRule)
   121  		if err != nil {
   122  			return nil, err
   123  		}
   124  
   125  		if rule.DomainMatcher == "" {
   126  			rule.DomainMatcher = c.DomainMatcher
   127  		}
   128  
   129  		config.Rule = append(config.Rule, rule)
   130  	}
   131  	for _, rawBalancer := range c.Balancers {
   132  		balancer, err := rawBalancer.Build()
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  		config.BalancingRule = append(config.BalancingRule, balancer)
   137  	}
   138  	return config, nil
   139  }
   140  
// RouterRule holds the fields shared by every routing rule. ParseRule decodes
// it first to inspect Type, and parseFieldRule embeds it in its full rule
// shape.
type RouterRule struct {
	// RuleTag is copied verbatim into the built routing rule.
	RuleTag     string `json:"ruleTag"`
	// Type selects the rule kind; ParseRule only accepts "" or "field"
	// (compared case-insensitively).
	Type        string `json:"type"`
	// OutboundTag is the target outbound; it takes precedence over
	// BalancerTag when both are set.
	OutboundTag string `json:"outboundTag"`
	// BalancerTag is the target balancer, used when OutboundTag is empty.
	BalancerTag string `json:"balancerTag"`

	// DomainMatcher optionally overrides the router-wide domain matcher for
	// this rule.
	DomainMatcher string `json:"domainMatcher"`
}
   149  
   150  func ParseIP(s string) (*router.CIDR, error) {
   151  	var addr, mask string
   152  	i := strings.Index(s, "/")
   153  	if i < 0 {
   154  		addr = s
   155  	} else {
   156  		addr = s[:i]
   157  		mask = s[i+1:]
   158  	}
   159  	ip := net.ParseAddress(addr)
   160  	switch ip.Family() {
   161  	case net.AddressFamilyIPv4:
   162  		bits := uint32(32)
   163  		if len(mask) > 0 {
   164  			bits64, err := strconv.ParseUint(mask, 10, 32)
   165  			if err != nil {
   166  				return nil, newError("invalid network mask for router: ", mask).Base(err)
   167  			}
   168  			bits = uint32(bits64)
   169  		}
   170  		if bits > 32 {
   171  			return nil, newError("invalid network mask for router: ", bits)
   172  		}
   173  		return &router.CIDR{
   174  			Ip:     []byte(ip.IP()),
   175  			Prefix: bits,
   176  		}, nil
   177  	case net.AddressFamilyIPv6:
   178  		bits := uint32(128)
   179  		if len(mask) > 0 {
   180  			bits64, err := strconv.ParseUint(mask, 10, 32)
   181  			if err != nil {
   182  				return nil, newError("invalid network mask for router: ", mask).Base(err)
   183  			}
   184  			bits = uint32(bits64)
   185  		}
   186  		if bits > 128 {
   187  			return nil, newError("invalid network mask for router: ", bits)
   188  		}
   189  		return &router.CIDR{
   190  			Ip:     []byte(ip.IP()),
   191  			Prefix: bits,
   192  		}, nil
   193  	default:
   194  		return nil, newError("unsupported address for router: ", s)
   195  	}
   196  }
   197  
   198  func loadGeoIP(code string) ([]*router.CIDR, error) {
   199  	return loadIP("geoip.dat", code)
   200  }
   201  
// Package-level caches consulted by loadFile, loadIP and loadSite.
//
// NOTE(review): no function visible in this file ever stores into these
// maps; the loaders only read them and otherwise return freshly loaded data.
// They are exported, so entries may be added from elsewhere; confirm before
// relying on them being populated.
var (
	// FileCache maps an asset file name to its raw contents.
	FileCache = make(map[string][]byte)
	// IPCache maps "<file>:<code>" to a decoded GeoIP entry.
	IPCache   = make(map[string]*router.GeoIP)
	// SiteCache maps "<file>:<code>" to a decoded GeoSite entry.
	SiteCache = make(map[string]*router.GeoSite)
)
   207  
   208  func loadFile(file string) ([]byte, error) {
   209  	if FileCache[file] == nil {
   210  		bs, err := filesystem.ReadAsset(file)
   211  		if err != nil {
   212  			return nil, newError("failed to open file: ", file).Base(err)
   213  		}
   214  		if len(bs) == 0 {
   215  			return nil, newError("empty file: ", file)
   216  		}
   217  		// Do not cache file, may save RAM when there
   218  		// are many files, but consume CPU each time.
   219  		return bs, nil
   220  		FileCache[file] = bs
   221  	}
   222  	return FileCache[file], nil
   223  }
   224  
   225  func loadIP(file, code string) ([]*router.CIDR, error) {
   226  	index := file + ":" + code
   227  	if IPCache[index] == nil {
   228  		bs, err := loadFile(file)
   229  		if err != nil {
   230  			return nil, newError("failed to load file: ", file).Base(err)
   231  		}
   232  		bs = find(bs, []byte(code))
   233  		if bs == nil {
   234  			return nil, newError("code not found in ", file, ": ", code)
   235  		}
   236  		var geoip router.GeoIP
   237  		if err := proto.Unmarshal(bs, &geoip); err != nil {
   238  			return nil, newError("error unmarshal IP in ", file, ": ", code).Base(err)
   239  		}
   240  		defer runtime.GC()     // or debug.FreeOSMemory()
   241  		return geoip.Cidr, nil // do not cache geoip
   242  		IPCache[index] = &geoip
   243  	}
   244  	return IPCache[index].Cidr, nil
   245  }
   246  
   247  func loadSite(file, code string) ([]*router.Domain, error) {
   248  	index := file + ":" + code
   249  	if SiteCache[index] == nil {
   250  		bs, err := loadFile(file)
   251  		if err != nil {
   252  			return nil, newError("failed to load file: ", file).Base(err)
   253  		}
   254  		bs = find(bs, []byte(code))
   255  		if bs == nil {
   256  			return nil, newError("list not found in ", file, ": ", code)
   257  		}
   258  		var geosite router.GeoSite
   259  		if err := proto.Unmarshal(bs, &geosite); err != nil {
   260  			return nil, newError("error unmarshal Site in ", file, ": ", code).Base(err)
   261  		}
   262  		defer runtime.GC()         // or debug.FreeOSMemory()
   263  		return geosite.Domain, nil // do not cache geosite
   264  		SiteCache[index] = &geosite
   265  	}
   266  	return SiteCache[index].Domain, nil
   267  }
   268  
   269  func DecodeVarint(buf []byte) (x uint64, n int) {
   270  	for shift := uint(0); shift < 64; shift += 7 {
   271  		if n >= len(buf) {
   272  			return 0, 0
   273  		}
   274  		b := uint64(buf[n])
   275  		n++
   276  		x |= (b & 0x7F) << shift
   277  		if (b & 0x80) == 0 {
   278  			return x, n
   279  		}
   280  	}
   281  
   282  	// The number is too large to represent in a 64-bit value.
   283  	return 0, 0
   284  }
   285  
   286  func find(data, code []byte) []byte {
   287  	codeL := len(code)
   288  	if codeL == 0 {
   289  		return nil
   290  	}
   291  	for {
   292  		dataL := len(data)
   293  		if dataL < 2 {
   294  			return nil
   295  		}
   296  		x, y := DecodeVarint(data[1:])
   297  		if x == 0 && y == 0 {
   298  			return nil
   299  		}
   300  		headL, bodyL := 1+y, int(x)
   301  		dataL -= headL
   302  		if dataL < bodyL {
   303  			return nil
   304  		}
   305  		data = data[headL:]
   306  		if int(data[1]) == codeL {
   307  			for i := 0; i < codeL && data[2+i] == code[i]; i++ {
   308  				if i+1 == codeL {
   309  					return data[:bodyL]
   310  				}
   311  			}
   312  		}
   313  		if dataL == bodyL {
   314  			return nil
   315  		}
   316  		data = data[bodyL:]
   317  	}
   318  }
   319  
// AttributeMatcher decides whether a domain entry satisfies one attribute
// condition.
type AttributeMatcher interface {
	// Match reports whether the given domain satisfies the condition.
	Match(*router.Domain) bool
}
   323  
   324  type BooleanMatcher string
   325  
   326  func (m BooleanMatcher) Match(domain *router.Domain) bool {
   327  	for _, attr := range domain.Attribute {
   328  		if attr.Key == string(m) {
   329  			return true
   330  		}
   331  	}
   332  	return false
   333  }
   334  
   335  type AttributeList struct {
   336  	matcher []AttributeMatcher
   337  }
   338  
   339  func (al *AttributeList) Match(domain *router.Domain) bool {
   340  	for _, matcher := range al.matcher {
   341  		if !matcher.Match(domain) {
   342  			return false
   343  		}
   344  	}
   345  	return true
   346  }
   347  
   348  func (al *AttributeList) IsEmpty() bool {
   349  	return len(al.matcher) == 0
   350  }
   351  
   352  func parseAttrs(attrs []string) *AttributeList {
   353  	al := new(AttributeList)
   354  	for _, attr := range attrs {
   355  		lc := strings.ToLower(attr)
   356  		al.matcher = append(al.matcher, BooleanMatcher(lc))
   357  	}
   358  	return al
   359  }
   360  
   361  func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) {
   362  	parts := strings.Split(siteWithAttr, "@")
   363  	if len(parts) == 0 {
   364  		return nil, newError("empty site")
   365  	}
   366  	country := strings.ToUpper(parts[0])
   367  	attrs := parseAttrs(parts[1:])
   368  	domains, err := loadSite(file, country)
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	if attrs.IsEmpty() {
   374  		return domains, nil
   375  	}
   376  
   377  	filteredDomains := make([]*router.Domain, 0, len(domains))
   378  	for _, domain := range domains {
   379  		if attrs.Match(domain) {
   380  			filteredDomains = append(filteredDomains, domain)
   381  		}
   382  	}
   383  
   384  	return filteredDomains, nil
   385  }
   386  
   387  func parseDomainRule(domain string) ([]*router.Domain, error) {
   388  	if strings.HasPrefix(domain, "geosite:") {
   389  		country := strings.ToUpper(domain[8:])
   390  		domains, err := loadGeositeWithAttr("geosite.dat", country)
   391  		if err != nil {
   392  			return nil, newError("failed to load geosite: ", country).Base(err)
   393  		}
   394  		return domains, nil
   395  	}
   396  	isExtDatFile := 0
   397  	{
   398  		const prefix = "ext:"
   399  		if strings.HasPrefix(domain, prefix) {
   400  			isExtDatFile = len(prefix)
   401  		}
   402  		const prefixQualified = "ext-domain:"
   403  		if strings.HasPrefix(domain, prefixQualified) {
   404  			isExtDatFile = len(prefixQualified)
   405  		}
   406  	}
   407  	if isExtDatFile != 0 {
   408  		kv := strings.Split(domain[isExtDatFile:], ":")
   409  		if len(kv) != 2 {
   410  			return nil, newError("invalid external resource: ", domain)
   411  		}
   412  		filename := kv[0]
   413  		country := kv[1]
   414  		domains, err := loadGeositeWithAttr(filename, country)
   415  		if err != nil {
   416  			return nil, newError("failed to load external sites: ", country, " from ", filename).Base(err)
   417  		}
   418  		return domains, nil
   419  	}
   420  
   421  	domainRule := new(router.Domain)
   422  	switch {
   423  	case strings.HasPrefix(domain, "regexp:"):
   424  		domainRule.Type = router.Domain_Regex
   425  		domainRule.Value = domain[7:]
   426  
   427  	case strings.HasPrefix(domain, "domain:"):
   428  		domainRule.Type = router.Domain_Domain
   429  		domainRule.Value = domain[7:]
   430  
   431  	case strings.HasPrefix(domain, "full:"):
   432  		domainRule.Type = router.Domain_Full
   433  		domainRule.Value = domain[5:]
   434  
   435  	case strings.HasPrefix(domain, "keyword:"):
   436  		domainRule.Type = router.Domain_Plain
   437  		domainRule.Value = domain[8:]
   438  
   439  	case strings.HasPrefix(domain, "dotless:"):
   440  		domainRule.Type = router.Domain_Regex
   441  		switch substr := domain[8:]; {
   442  		case substr == "":
   443  			domainRule.Value = "^[^.]*$"
   444  		case !strings.Contains(substr, "."):
   445  			domainRule.Value = "^[^.]*" + substr + "[^.]*$"
   446  		default:
   447  			return nil, newError("substr in dotless rule should not contain a dot: ", substr)
   448  		}
   449  
   450  	default:
   451  		domainRule.Type = router.Domain_Plain
   452  		domainRule.Value = domain
   453  	}
   454  	return []*router.Domain{domainRule}, nil
   455  }
   456  
   457  func ToCidrList(ips StringList) ([]*router.GeoIP, error) {
   458  	var geoipList []*router.GeoIP
   459  	var customCidrs []*router.CIDR
   460  
   461  	for _, ip := range ips {
   462  		if strings.HasPrefix(ip, "geoip:") {
   463  			country := ip[6:]
   464  			isReverseMatch := false
   465  			if strings.HasPrefix(ip, "geoip:!") {
   466  				country = ip[7:]
   467  				isReverseMatch = true
   468  			}
   469  			if len(country) == 0 {
   470  				return nil, newError("empty country name in rule")
   471  			}
   472  			geoip, err := loadGeoIP(strings.ToUpper(country))
   473  			if err != nil {
   474  				return nil, newError("failed to load GeoIP: ", country).Base(err)
   475  			}
   476  
   477  			geoipList = append(geoipList, &router.GeoIP{
   478  				CountryCode:  strings.ToUpper(country),
   479  				Cidr:         geoip,
   480  				ReverseMatch: isReverseMatch,
   481  			})
   482  			continue
   483  		}
   484  		isExtDatFile := 0
   485  		{
   486  			const prefix = "ext:"
   487  			if strings.HasPrefix(ip, prefix) {
   488  				isExtDatFile = len(prefix)
   489  			}
   490  			const prefixQualified = "ext-ip:"
   491  			if strings.HasPrefix(ip, prefixQualified) {
   492  				isExtDatFile = len(prefixQualified)
   493  			}
   494  		}
   495  		if isExtDatFile != 0 {
   496  			kv := strings.Split(ip[isExtDatFile:], ":")
   497  			if len(kv) != 2 {
   498  				return nil, newError("invalid external resource: ", ip)
   499  			}
   500  
   501  			filename := kv[0]
   502  			country := kv[1]
   503  			if len(filename) == 0 || len(country) == 0 {
   504  				return nil, newError("empty filename or empty country in rule")
   505  			}
   506  
   507  			isReverseMatch := false
   508  			if strings.HasPrefix(country, "!") {
   509  				country = country[1:]
   510  				isReverseMatch = true
   511  			}
   512  			geoip, err := loadIP(filename, strings.ToUpper(country))
   513  			if err != nil {
   514  				return nil, newError("failed to load IPs: ", country, " from ", filename).Base(err)
   515  			}
   516  
   517  			geoipList = append(geoipList, &router.GeoIP{
   518  				CountryCode:  strings.ToUpper(filename + "_" + country),
   519  				Cidr:         geoip,
   520  				ReverseMatch: isReverseMatch,
   521  			})
   522  
   523  			continue
   524  		}
   525  
   526  		ipRule, err := ParseIP(ip)
   527  		if err != nil {
   528  			return nil, newError("invalid IP: ", ip).Base(err)
   529  		}
   530  		customCidrs = append(customCidrs, ipRule)
   531  	}
   532  
   533  	if len(customCidrs) > 0 {
   534  		geoipList = append(geoipList, &router.GeoIP{
   535  			Cidr: customCidrs,
   536  		})
   537  	}
   538  
   539  	return geoipList, nil
   540  }
   541  
   542  func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
   543  	type RawFieldRule struct {
   544  		RouterRule
   545  		Domain     *StringList       `json:"domain"`
   546  		Domains    *StringList       `json:"domains"`
   547  		IP         *StringList       `json:"ip"`
   548  		Port       *PortList         `json:"port"`
   549  		Network    *NetworkList      `json:"network"`
   550  		SourceIP   *StringList       `json:"source"`
   551  		SourcePort *PortList         `json:"sourcePort"`
   552  		User       *StringList       `json:"user"`
   553  		InboundTag *StringList       `json:"inboundTag"`
   554  		Protocols  *StringList       `json:"protocol"`
   555  		Attributes map[string]string `json:"attrs"`
   556  	}
   557  	rawFieldRule := new(RawFieldRule)
   558  	err := json.Unmarshal(msg, rawFieldRule)
   559  	if err != nil {
   560  		return nil, err
   561  	}
   562  
   563  	rule := new(router.RoutingRule)
   564  	rule.RuleTag = rawFieldRule.RuleTag
   565  	switch {
   566  	case len(rawFieldRule.OutboundTag) > 0:
   567  		rule.TargetTag = &router.RoutingRule_Tag{
   568  			Tag: rawFieldRule.OutboundTag,
   569  		}
   570  	case len(rawFieldRule.BalancerTag) > 0:
   571  		rule.TargetTag = &router.RoutingRule_BalancingTag{
   572  			BalancingTag: rawFieldRule.BalancerTag,
   573  		}
   574  	default:
   575  		return nil, newError("neither outboundTag nor balancerTag is specified in routing rule")
   576  	}
   577  
   578  	if rawFieldRule.DomainMatcher != "" {
   579  		rule.DomainMatcher = rawFieldRule.DomainMatcher
   580  	}
   581  
   582  	if rawFieldRule.Domain != nil {
   583  		for _, domain := range *rawFieldRule.Domain {
   584  			rules, err := parseDomainRule(domain)
   585  			if err != nil {
   586  				return nil, newError("failed to parse domain rule: ", domain).Base(err)
   587  			}
   588  			rule.Domain = append(rule.Domain, rules...)
   589  		}
   590  	}
   591  
   592  	if rawFieldRule.Domains != nil {
   593  		for _, domain := range *rawFieldRule.Domains {
   594  			rules, err := parseDomainRule(domain)
   595  			if err != nil {
   596  				return nil, newError("failed to parse domain rule: ", domain).Base(err)
   597  			}
   598  			rule.Domain = append(rule.Domain, rules...)
   599  		}
   600  	}
   601  
   602  	if rawFieldRule.IP != nil {
   603  		geoipList, err := ToCidrList(*rawFieldRule.IP)
   604  		if err != nil {
   605  			return nil, err
   606  		}
   607  		rule.Geoip = geoipList
   608  	}
   609  
   610  	if rawFieldRule.Port != nil {
   611  		rule.PortList = rawFieldRule.Port.Build()
   612  	}
   613  
   614  	if rawFieldRule.Network != nil {
   615  		rule.Networks = rawFieldRule.Network.Build()
   616  	}
   617  
   618  	if rawFieldRule.SourceIP != nil {
   619  		geoipList, err := ToCidrList(*rawFieldRule.SourceIP)
   620  		if err != nil {
   621  			return nil, err
   622  		}
   623  		rule.SourceGeoip = geoipList
   624  	}
   625  
   626  	if rawFieldRule.SourcePort != nil {
   627  		rule.SourcePortList = rawFieldRule.SourcePort.Build()
   628  	}
   629  
   630  	if rawFieldRule.User != nil {
   631  		for _, s := range *rawFieldRule.User {
   632  			rule.UserEmail = append(rule.UserEmail, s)
   633  		}
   634  	}
   635  
   636  	if rawFieldRule.InboundTag != nil {
   637  		for _, s := range *rawFieldRule.InboundTag {
   638  			rule.InboundTag = append(rule.InboundTag, s)
   639  		}
   640  	}
   641  
   642  	if rawFieldRule.Protocols != nil {
   643  		for _, s := range *rawFieldRule.Protocols {
   644  			rule.Protocol = append(rule.Protocol, s)
   645  		}
   646  	}
   647  
   648  	if len(rawFieldRule.Attributes) > 0 {
   649  		rule.Attributes = rawFieldRule.Attributes
   650  	}
   651  
   652  	return rule, nil
   653  }
   654  
   655  func ParseRule(msg json.RawMessage) (*router.RoutingRule, error) {
   656  	rawRule := new(RouterRule)
   657  	err := json.Unmarshal(msg, rawRule)
   658  	if err != nil {
   659  		return nil, newError("invalid router rule").Base(err)
   660  	}
   661  	if rawRule.Type == "" || strings.EqualFold(rawRule.Type, "field") {
   662  		fieldrule, err := parseFieldRule(msg)
   663  		if err != nil {
   664  			return nil, newError("invalid field rule").Base(err)
   665  		}
   666  		return fieldrule, nil
   667  	}
   668  	return nil, newError("unknown router rule type: ", rawRule.Type)
   669  }