github.com/database64128/shadowsocks-go@v1.7.0/router/route.go (about)

     1  package router
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/netip"
     7  
     8  	"github.com/database64128/shadowsocks-go/bitset"
     9  	"github.com/database64128/shadowsocks-go/conn"
    10  	"github.com/database64128/shadowsocks-go/dns"
    11  	"github.com/database64128/shadowsocks-go/domainset"
    12  	"github.com/database64128/shadowsocks-go/portset"
    13  	"github.com/database64128/shadowsocks-go/slices"
    14  	"github.com/database64128/shadowsocks-go/zerocopy"
    15  	"github.com/oschwald/geoip2-golang"
    16  	"go.uber.org/zap"
    17  	"go4.org/netipx"
    18  )
    19  
    20  // ErrRejected is a special error that indicates the request is rejected.
    21  var ErrRejected = errors.New("rejected")
    22  
    23  var errPointlessPortCriteria = errors.New("matching all ports is equivalent to not having any port filtering rules")
    24  
// RouteConfig is a routing rule.
//
// [RouteConfig.Route] validates it and builds the corresponding [Route].
// Non-empty criteria are combined with AND, with two exceptions that are
// combined with OR internally:
//   - the source IP criteria (FromPrefixes, FromPrefixSets, FromGeoIPCountries);
//   - the destination criteria (ToDomains, ToDomainSets, ToPrefixes,
//     ToPrefixSets, ToGeoIPCountries).
type RouteConfig struct {
	// Name of this route. Used in logs to identify matched routes.
	// Must not be empty or "default".
	Name string `json:"name"`

	// Apply this route to "tcp" or "udp" only. If empty, match all requests.
	// Any other value is rejected.
	Network string `json:"network"`

	// Route matched requests to this client. Must not be empty.
	// The special value "reject" rejects matched requests instead of routing them.
	// Otherwise a client of this name must exist for every network the route applies to.
	Client string `json:"client"`

	// When matching a domain target to IP prefixes, use this resolver to resolve the domain name.
	// If unspecified, use all resolvers by order. If set, it must name a known resolver.
	Resolver string `json:"resolver"`

	// Match requests from these servers. If empty, match all requests.
	// Each entry must name a known server.
	FromServers []string `json:"fromServers"`

	// Match requests from these users. If empty, match all requests.
	FromUsers []string `json:"fromUsers"`

	// Match requests from these ports. If empty, match all requests.
	// Port 0 is not allowed. Together with FromPortRanges, the ports must not
	// cover every port from 1 to 65535.
	FromPorts []uint16 `json:"fromPorts"`

	// Match requests from these ports and port ranges. If empty, match all requests.
	// The string is parsed by portset.PortSet.Parse and merged with FromPorts.
	FromPortRanges string `json:"fromPortRanges"`

	// Match requests from IP addresses in these prefixes. If empty, match all requests.
	FromPrefixes []netip.Prefix `json:"fromPrefixes"`

	// Match requests from IP addresses in these prefix sets. If empty, match all requests.
	// Each entry must name a known prefix set.
	FromPrefixSets []string `json:"fromPrefixSets"`

	// Match requests from IP addresses in these countries. If empty, match all requests.
	// Entries are compared with the ISO country code reported by the GeoIP
	// database, which must be configured.
	FromGeoIPCountries []string `json:"fromGeoIPCountries"`

	// Match requests to these ports. If empty, match all requests.
	// Port 0 is not allowed. Together with ToPortRanges, the ports must not
	// cover every port from 1 to 65535.
	ToPorts []uint16 `json:"toPorts"`

	// Match requests to these ports and port ranges. If empty, match all requests.
	// The string is parsed by portset.PortSet.Parse and merged with ToPorts.
	ToPortRanges string `json:"toPortRanges"`

	// Match requests to these domain targets. If empty, match all requests.
	ToDomains []string `json:"toDomains"`

	// Match requests to domains in these domain sets. If empty, match all requests.
	// Each entry must name a known domain set.
	ToDomainSets []string `json:"toDomainSets"`

	// Require the matched domain target to resolve to IP addresses in these prefixes.
	// Only valid together with ToDomains or ToDomainSets; requires resolvers.
	ToMatchedDomainExpectedPrefixes []netip.Prefix `json:"toMatchedDomainExpectedPrefixes"`

	// Require the matched domain target to resolve to IP addresses in these prefix sets.
	// Only valid together with ToDomains or ToDomainSets; requires resolvers.
	ToMatchedDomainExpectedPrefixSets []string `json:"toMatchedDomainExpectedPrefixSets"`

	// Require the matched domain target to resolve to IP addresses in these countries.
	// Only valid together with ToDomains or ToDomainSets; requires resolvers and
	// a GeoIP database.
	ToMatchedDomainExpectedGeoIPCountries []string `json:"toMatchedDomainExpectedGeoIPCountries"`

	// Match requests to IP addresses in these prefixes. If empty, match all requests.
	// Unless DisableNameResolutionForIPRules is set, domain targets are resolved
	// and their resolved addresses are matched as well.
	ToPrefixes []netip.Prefix `json:"toPrefixes"`

	// Match requests to IP addresses in these prefix sets. If empty, match all requests.
	// Resolution of domain targets follows the same rule as ToPrefixes.
	ToPrefixSets []string `json:"toPrefixSets"`

	// Match requests to IP addresses in these countries. If empty, match all requests.
	// Resolution of domain targets follows the same rule as ToPrefixes.
	ToGeoIPCountries []string `json:"toGeoIPCountries"`

	// Do not resolve destination domains to match IP rules.
	// When set, ToPrefixes, ToPrefixSets and ToGeoIPCountries only match IP address targets.
	DisableNameResolutionForIPRules bool `json:"disableNameResolutionForIPRules"`

	// Invert source server matching logic. Match requests from all servers except those in FromServers.
	InvertFromServers bool `json:"invertFromServers"`

	// Invert source user matching logic. Match requests from all users except those in FromUsers.
	InvertFromUsers bool `json:"invertFromUsers"`

	// Invert source IP prefix matching logic. Match requests from all IP prefixes except those in FromPrefixes or FromPrefixSets.
	InvertFromPrefixes bool `json:"invertFromPrefixes"`

	// Invert source GeoIP country matching logic. Match requests from all countries except those in FromGeoIPCountries.
	InvertFromGeoIPCountries bool `json:"invertFromGeoIPCountries"`

	// Invert source port matching logic. Match requests from all ports except those in FromPorts or FromPortRanges.
	InvertFromPorts bool `json:"invertFromPorts"`

	// Invert destination domain matching logic. Match requests to all domains except those in ToDomains or ToDomainSets.
	// When ToMatchedDomainExpected* criteria are set, the inversion applies to the
	// combined domain-and-expected-IP check.
	InvertToDomains bool `json:"invertToDomains"`

	// Invert destination domain expected prefix matching logic. Match requests to all domains except those whose resolved IP addresses are in ToMatchedDomainExpectedPrefixes or ToMatchedDomainExpectedPrefixSets.
	InvertToMatchedDomainExpectedPrefixes bool `json:"invertToMatchedDomainExpectedPrefixes"`

	// Invert destination domain expected GeoIP country matching logic. Match requests to all domains except those whose resolved IP addresses are in ToMatchedDomainExpectedGeoIPCountries.
	InvertToMatchedDomainExpectedGeoIPCountries bool `json:"invertToMatchedDomainExpectedGeoIPCountries"`

	// Invert destination IP prefix matching logic. Match requests to all IP prefixes except those in ToPrefixes or ToPrefixSets.
	InvertToPrefixes bool `json:"invertToPrefixes"`

	// Invert destination GeoIP country matching logic. Match requests to all countries except those in ToGeoIPCountries.
	InvertToGeoIPCountries bool `json:"invertToGeoIPCountries"`

	// Invert destination port matching logic. Match requests to all ports except those in ToPorts or ToPortRanges.
	InvertToPorts bool `json:"invertToPorts"`
}
   127  
   128  // Route creates a route from the RouteConfig.
   129  func (rc *RouteConfig) Route(geoip *geoip2.Reader, logger *zap.Logger, resolvers []*dns.Resolver, resolverMap map[string]*dns.Resolver, tcpClientMap map[string]zerocopy.TCPClient, udpClientMap map[string]zerocopy.UDPClient, serverIndexByName map[string]int, domainSetMap map[string]domainset.DomainSet, prefixSetMap map[string]*netipx.IPSet) (Route, error) {
   130  	// Bad name.
   131  	switch rc.Name {
   132  	case "", "default":
   133  		return Route{}, errors.New("route name cannot be empty or 'default'")
   134  	}
   135  
   136  	// Has GeoIP criteria but no GeoIP database.
   137  	if geoip == nil && (len(rc.FromGeoIPCountries) > 0 || len(rc.ToGeoIPCountries) > 0 || len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0) {
   138  		return Route{}, errors.New("missing GeoLite2 country database path")
   139  	}
   140  
   141  	// Needs to resolve domain names but has no resolvers.
   142  	if len(resolvers) == 0 &&
   143  		(len(rc.ToMatchedDomainExpectedPrefixes) > 0 ||
   144  			len(rc.ToMatchedDomainExpectedPrefixSets) > 0 ||
   145  			len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 ||
   146  			!rc.DisableNameResolutionForIPRules &&
   147  				(len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 || len(rc.ToGeoIPCountries) > 0)) {
   148  		return Route{}, errors.New("missing resolvers")
   149  	}
   150  
   151  	// Has resolved IP expectations but no destination domain criteria.
   152  	if len(rc.ToDomains) == 0 && len(rc.ToDomainSets) == 0 &&
   153  		(len(rc.ToMatchedDomainExpectedPrefixes) > 0 ||
   154  			len(rc.ToMatchedDomainExpectedPrefixSets) > 0 ||
   155  			len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0) {
   156  		return Route{}, errors.New("missing destination domain criteria")
   157  	}
   158  
   159  	if rc.Resolver != "" {
   160  		resolver, ok := resolverMap[rc.Resolver]
   161  		if !ok {
   162  			return Route{}, fmt.Errorf("resolver not found: %s", rc.Resolver)
   163  		}
   164  		resolvers = []*dns.Resolver{resolver}
   165  	}
   166  
   167  	route := Route{name: rc.Name}
   168  
   169  	switch rc.Network {
   170  	case "":
   171  	case "tcp":
   172  		route.AddCriterion(NetworkTCPCriterion{}, false)
   173  	case "udp":
   174  		route.AddCriterion(NetworkUDPCriterion{}, false)
   175  	default:
   176  		return Route{}, fmt.Errorf("invalid network: %s", rc.Network)
   177  	}
   178  
   179  	if rc.Client != "reject" {
   180  		switch rc.Network {
   181  		case "", "tcp":
   182  			route.tcpClient = tcpClientMap[rc.Client]
   183  			if route.tcpClient == nil {
   184  				return Route{}, fmt.Errorf("TCP client not found: %s", rc.Client)
   185  			}
   186  		}
   187  
   188  		switch rc.Network {
   189  		case "", "udp":
   190  			route.udpClient = udpClientMap[rc.Client]
   191  			if route.udpClient == nil {
   192  				return Route{}, fmt.Errorf("UDP client not found: %s", rc.Client)
   193  			}
   194  		}
   195  	}
   196  
   197  	if len(rc.FromServers) > 0 {
   198  		sourceServerSet := bitset.NewBitSet(uint(len(serverIndexByName)))
   199  
   200  		for _, server := range rc.FromServers {
   201  			index, ok := serverIndexByName[server]
   202  			if !ok {
   203  				return Route{}, fmt.Errorf("server not found: %s", server)
   204  			}
   205  			sourceServerSet.Set(uint(index))
   206  		}
   207  
   208  		route.AddCriterion(SourceServerCriterion(sourceServerSet), rc.InvertFromServers)
   209  	}
   210  
   211  	if len(rc.FromUsers) > 0 {
   212  		route.AddCriterion(SourceUserCriterion(rc.FromUsers), rc.InvertFromUsers)
   213  	}
   214  
   215  	if len(rc.FromPorts) > 0 || rc.FromPortRanges != "" {
   216  		var portSet portset.PortSet
   217  
   218  		for _, port := range rc.FromPorts {
   219  			if port == 0 {
   220  				return Route{}, fmt.Errorf("bad fromPorts: %w", portset.ErrZeroPort)
   221  			}
   222  			portSet.Add(port)
   223  		}
   224  
   225  		if err := portSet.Parse(rc.FromPortRanges); err != nil {
   226  			return Route{}, fmt.Errorf("failed to parse source port ranges: %w", err)
   227  		}
   228  
   229  		portCount := portSet.Count()
   230  		switch portCount {
   231  		case 0:
   232  			panic("unreachable")
   233  		case 1:
   234  			route.AddCriterion(SourcePortCriterion(portSet.First()), rc.InvertFromPorts)
   235  		case 65535:
   236  			return Route{}, fmt.Errorf("bad source port criteria: %w", errPointlessPortCriteria)
   237  		default:
   238  			portRangeCount := portSet.RangeCount()
   239  			if portRangeCount <= 16 {
   240  				route.AddCriterion(SourcePortRangeSetCriterion(portSet.RangeSet()), rc.InvertFromPorts)
   241  			} else {
   242  				sourcePortSetCriterion := SourcePortSetCriterion(portSet)
   243  				route.AddCriterion(&sourcePortSetCriterion, rc.InvertFromPorts)
   244  			}
   245  		}
   246  	}
   247  
   248  	if len(rc.FromPrefixes) > 0 || len(rc.FromPrefixSets) > 0 || len(rc.FromGeoIPCountries) > 0 {
   249  		var group CriterionGroupOR
   250  
   251  		if len(rc.FromPrefixes) > 0 || len(rc.FromPrefixSets) > 0 {
   252  			var sb netipx.IPSetBuilder
   253  
   254  			for _, prefix := range rc.FromPrefixes {
   255  				sb.AddPrefix(prefix)
   256  			}
   257  
   258  			for _, prefixSet := range rc.FromPrefixSets {
   259  				s, ok := prefixSetMap[prefixSet]
   260  				if !ok {
   261  					return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
   262  				}
   263  				sb.AddSet(s)
   264  			}
   265  
   266  			sourceIPSet, err := sb.IPSet()
   267  			if err != nil {
   268  				return Route{}, fmt.Errorf("failed to build sourceIPSet: %w", err)
   269  			}
   270  
   271  			group.AddCriterion((*SourceIPCriterion)(sourceIPSet), rc.InvertFromPrefixes)
   272  		}
   273  
   274  		if len(rc.FromGeoIPCountries) > 0 {
   275  			group.AddCriterion(SourceGeoIPCountryCriterion{
   276  				countries: rc.FromGeoIPCountries,
   277  				geoip:     geoip,
   278  				logger:    logger,
   279  			}, rc.InvertFromGeoIPCountries)
   280  		}
   281  
   282  		route.criteria = group.AppendTo(route.criteria)
   283  	}
   284  
   285  	if len(rc.ToPorts) > 0 || rc.ToPortRanges != "" {
   286  		var portSet portset.PortSet
   287  
   288  		for _, port := range rc.ToPorts {
   289  			if port == 0 {
   290  				return Route{}, fmt.Errorf("bad toPorts: %w", portset.ErrZeroPort)
   291  			}
   292  			portSet.Add(port)
   293  		}
   294  
   295  		if err := portSet.Parse(rc.ToPortRanges); err != nil {
   296  			return Route{}, fmt.Errorf("failed to parse destination port ranges: %w", err)
   297  		}
   298  
   299  		portCount := portSet.Count()
   300  		switch portCount {
   301  		case 0:
   302  			panic("unreachable")
   303  		case 1:
   304  			route.AddCriterion(DestPortCriterion(portSet.First()), rc.InvertToPorts)
   305  		case 65535:
   306  			return Route{}, fmt.Errorf("bad destination port criteria: %w", errPointlessPortCriteria)
   307  		default:
   308  			portRangeCount := portSet.RangeCount()
   309  			if portRangeCount <= 16 {
   310  				route.AddCriterion(DestPortRangeSetCriterion(portSet.RangeSet()), rc.InvertToPorts)
   311  			} else {
   312  				destPortSetCriterion := DestPortSetCriterion(portSet)
   313  				route.AddCriterion(&destPortSetCriterion, rc.InvertToPorts)
   314  			}
   315  		}
   316  	}
   317  
   318  	if len(rc.ToDomains) > 0 || len(rc.ToDomainSets) > 0 || len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 || len(rc.ToGeoIPCountries) > 0 {
   319  		var group CriterionGroupOR
   320  
   321  		if len(rc.ToDomains) > 0 || len(rc.ToDomainSets) > 0 {
   322  			var defaultDomainSetCount int
   323  
   324  			if len(rc.ToDomains) > 0 {
   325  				defaultDomainSetCount = 1
   326  			}
   327  
   328  			domainSets := make([]domainset.DomainSet, defaultDomainSetCount+len(rc.ToDomainSets))
   329  
   330  			if defaultDomainSetCount == 1 {
   331  				mb := domainset.DomainLinearMatcher(rc.ToDomains)
   332  				ds, err := mb.AppendTo(nil)
   333  				if err != nil {
   334  					return Route{}, err
   335  				}
   336  				domainSets[0] = ds
   337  			}
   338  
   339  			for i, tds := range rc.ToDomainSets {
   340  				ds, ok := domainSetMap[tds]
   341  				if !ok {
   342  					return Route{}, fmt.Errorf("domain set not found: %s", tds)
   343  				}
   344  				domainSets[defaultDomainSetCount+i] = ds
   345  			}
   346  
   347  			if len(rc.ToMatchedDomainExpectedPrefixes) > 0 || len(rc.ToMatchedDomainExpectedPrefixSets) > 0 || len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 {
   348  				var expectedIPCriterionGroup CriterionGroupOR
   349  
   350  				if len(rc.ToMatchedDomainExpectedPrefixes) > 0 || len(rc.ToMatchedDomainExpectedPrefixSets) > 0 {
   351  					var sb netipx.IPSetBuilder
   352  
   353  					for _, prefix := range rc.ToMatchedDomainExpectedPrefixes {
   354  						sb.AddPrefix(prefix)
   355  					}
   356  
   357  					for _, prefixSet := range rc.ToMatchedDomainExpectedPrefixSets {
   358  						s, ok := prefixSetMap[prefixSet]
   359  						if !ok {
   360  							return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
   361  						}
   362  						sb.AddSet(s)
   363  					}
   364  
   365  					expectedIPSet, err := sb.IPSet()
   366  					if err != nil {
   367  						return Route{}, fmt.Errorf("failed to build expectedIPSet: %w", err)
   368  					}
   369  
   370  					expectedIPCriterionGroup.AddCriterion(DestResolvedIPCriterion{expectedIPSet, resolvers}, rc.InvertToMatchedDomainExpectedPrefixes)
   371  				}
   372  
   373  				if len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 {
   374  					expectedIPCriterionGroup.AddCriterion(DestResolvedGeoIPCountryCriterion{
   375  						countries: rc.ToMatchedDomainExpectedGeoIPCountries,
   376  						geoip:     geoip,
   377  						logger:    logger,
   378  						resolvers: resolvers,
   379  					}, rc.InvertToMatchedDomainExpectedGeoIPCountries)
   380  				}
   381  
   382  				group.AddCriterion(DestDomainExpectedIPCriterion{domainSets, expectedIPCriterionGroup.Criterion()}, rc.InvertToDomains)
   383  			} else {
   384  				group.AddCriterion(DestDomainCriterion(domainSets), rc.InvertToDomains)
   385  			}
   386  		}
   387  
   388  		if len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 {
   389  			var sb netipx.IPSetBuilder
   390  
   391  			for _, prefix := range rc.ToPrefixes {
   392  				sb.AddPrefix(prefix)
   393  			}
   394  
   395  			for _, prefixSet := range rc.ToPrefixSets {
   396  				s, ok := prefixSetMap[prefixSet]
   397  				if !ok {
   398  					return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
   399  				}
   400  				sb.AddSet(s)
   401  			}
   402  
   403  			destIPSet, err := sb.IPSet()
   404  			if err != nil {
   405  				return Route{}, fmt.Errorf("failed to build destIPSet: %w", err)
   406  			}
   407  
   408  			if rc.DisableNameResolutionForIPRules {
   409  				group.AddCriterion((*DestIPCriterion)(destIPSet), rc.InvertToPrefixes)
   410  			} else {
   411  				group.AddCriterion(DestResolvedIPCriterion{destIPSet, resolvers}, rc.InvertToPrefixes)
   412  			}
   413  		}
   414  
   415  		if len(rc.ToGeoIPCountries) > 0 {
   416  			if rc.DisableNameResolutionForIPRules {
   417  				group.AddCriterion(DestGeoIPCountryCriterion{
   418  					countries: rc.ToGeoIPCountries,
   419  					geoip:     geoip,
   420  					logger:    logger,
   421  				}, rc.InvertToGeoIPCountries)
   422  			} else {
   423  				group.AddCriterion(DestResolvedGeoIPCountryCriterion{
   424  					countries: rc.ToGeoIPCountries,
   425  					geoip:     geoip,
   426  					logger:    logger,
   427  					resolvers: resolvers,
   428  				}, rc.InvertToGeoIPCountries)
   429  			}
   430  		}
   431  
   432  		route.criteria = group.AppendTo(route.criteria)
   433  	}
   434  
   435  	return route, nil
   436  }
   437  
// Route controls which client a request is routed to.
//
// A request matches the route when every criterion is met (see [Route.Match]).
// A nil tcpClient or udpClient means matched requests on that network are
// rejected: [Route.TCPClient] and [Route.UDPClient] return [ErrRejected].
type Route struct {
	// name identifies the route in logs; returned by String.
	name      string
	// criteria are evaluated in order and ANDed together by Match.
	criteria  []Criterion
	// tcpClient is nil when the route rejects or does not apply to TCP.
	tcpClient zerocopy.TCPClient
	// udpClient is nil when the route rejects or does not apply to UDP.
	udpClient zerocopy.UDPClient
}
   445  
   446  // String returns the name of the route.
   447  func (r *Route) String() string {
   448  	return r.name
   449  }
   450  
   451  // AddCriterion adds a criterion to the route.
   452  func (r *Route) AddCriterion(criterion Criterion, invert bool) {
   453  	if invert {
   454  		criterion = InvertedCriterion{Inner: criterion}
   455  	}
   456  	r.criteria = append(r.criteria, criterion)
   457  }
   458  
   459  // Match returns whether the request matches the route.
   460  func (r *Route) Match(network protocol, requestInfo RequestInfo) (bool, error) {
   461  	for _, criterion := range r.criteria {
   462  		met, err := criterion.Meet(network, requestInfo)
   463  		if !met {
   464  			return false, err
   465  		}
   466  	}
   467  	return true, nil
   468  }
   469  
   470  // TCPClient returns the TCP client to use for the request.
   471  func (r *Route) TCPClient() (zerocopy.TCPClient, error) {
   472  	if r.tcpClient == nil {
   473  		return nil, ErrRejected
   474  	}
   475  	return r.tcpClient, nil
   476  }
   477  
   478  // UDPClient returns the UDP client to use for the request.
   479  func (r *Route) UDPClient() (zerocopy.UDPClient, error) {
   480  	if r.udpClient == nil {
   481  		return nil, ErrRejected
   482  	}
   483  	return r.udpClient, nil
   484  }
   485  
// Criterion is used by [Route] to determine whether a request matches the route.
//
// The implementations in this file return false together with any non-nil
// error. [Route.Match] only surfaces an error from a criterion that is not met.
type Criterion interface {
	// Meet returns whether the request meets the criterion.
	// network is the transport protocol the request arrived on.
	Meet(network protocol, requestInfo RequestInfo) (bool, error)
}
   491  
   492  // InvertedCriterion is like the inner criterion, but inverted.
   493  type InvertedCriterion struct {
   494  	Inner Criterion
   495  }
   496  
   497  // Meet implements the Criterion Meet method.
   498  func (c InvertedCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   499  	met, err := c.Inner.Meet(network, requestInfo)
   500  	if err != nil {
   501  		return false, err
   502  	}
   503  	return !met, nil
   504  }
   505  
   506  // CriterionGroupOR groups multiple criteria together with OR logic.
   507  type CriterionGroupOR struct {
   508  	Criteria []Criterion
   509  }
   510  
   511  // AddCriterion adds a criterion to the group.
   512  func (g *CriterionGroupOR) AddCriterion(criterion Criterion, invert bool) {
   513  	if invert {
   514  		criterion = InvertedCriterion{Inner: criterion}
   515  	}
   516  	g.Criteria = append(g.Criteria, criterion)
   517  }
   518  
   519  // Meet returns whether the request meets any of the criteria.
   520  func (g CriterionGroupOR) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   521  	for _, criterion := range g.Criteria {
   522  		met, err := criterion.Meet(network, requestInfo)
   523  		if err != nil {
   524  			return false, err
   525  		}
   526  		if met {
   527  			return true, nil
   528  		}
   529  	}
   530  	return false, nil
   531  }
   532  
   533  // Criterion returns a single criterion that represents the group, or nil if the group is empty.
   534  func (g CriterionGroupOR) Criterion() Criterion {
   535  	switch len(g.Criteria) {
   536  	case 0:
   537  		return nil
   538  	case 1:
   539  		return g.Criteria[0]
   540  	default:
   541  		return g
   542  	}
   543  }
   544  
   545  // AppendTo appends the group to the criterion slice.
   546  // When there are more than one criterion in the group, the group itself is appended.
   547  // When there is only one criterion in the group, the criterion is appended directly.
   548  // When there are no criteria in the group, the criterion slice is returned unchanged.
   549  func (g CriterionGroupOR) AppendTo(criteria []Criterion) []Criterion {
   550  	switch len(g.Criteria) {
   551  	case 0:
   552  		return criteria
   553  	case 1:
   554  		return append(criteria, g.Criteria[0])
   555  	default:
   556  		return append(criteria, g)
   557  	}
   558  }
   559  
   560  type protocol byte
   561  
   562  const (
   563  	protocolTCP protocol = iota
   564  	protocolUDP
   565  )
   566  
// RequestInfo contains information about a request that can be met by one or more criteria.
type RequestInfo struct {
	// ServerIndex is the index of the server that received the request;
	// SourceServerCriterion looks it up in its bit set.
	ServerIndex    int
	// Username is compared with the configured FromUsers by SourceUserCriterion.
	Username       string
	// SourceAddrPort is the request's source address and port. IP criteria
	// unmap IPv4-mapped IPv6 addresses before matching prefixes.
	SourceAddrPort netip.AddrPort
	// TargetAddr is the requested destination, either an IP address or a domain name.
	TargetAddr     conn.Addr
}
   574  
   575  // NetworkTCPCriterion restricts the network to TCP.
   576  type NetworkTCPCriterion struct{}
   577  
   578  // Meet implements the Criterion Meet method.
   579  func (NetworkTCPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   580  	return network == protocolTCP, nil
   581  }
   582  
   583  // NetworkUDPCriterion restricts the network to UDP.
   584  type NetworkUDPCriterion struct{}
   585  
   586  // Meet implements the Criterion Meet method.
   587  func (NetworkUDPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   588  	return network == protocolUDP, nil
   589  }
   590  
   591  // SourceServerCriterion restricts the source server.
   592  type SourceServerCriterion bitset.BitSet
   593  
   594  // Meet implements the Criterion Meet method.
   595  func (c SourceServerCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   596  	return bitset.BitSet(c).IsSet(uint(requestInfo.ServerIndex)), nil
   597  }
   598  
   599  // SourceUserCriterion restricts the source user.
   600  type SourceUserCriterion []string
   601  
   602  // Meet implements the Criterion Meet method.
   603  func (c SourceUserCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   604  	return slices.Contains(c, requestInfo.Username), nil
   605  }
   606  
   607  // SourcePortCriterion restricts the source port.
   608  type SourcePortCriterion uint16
   609  
   610  // Meet implements the Criterion Meet method.
   611  func (c SourcePortCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   612  	return uint16(c) == requestInfo.SourceAddrPort.Port(), nil
   613  }
   614  
   615  // SourcePortRangeSetCriterion restricts the source port to ports in a port range set.
   616  type SourcePortRangeSetCriterion portset.PortRangeSet
   617  
   618  // Meet implements the Criterion Meet method.
   619  func (c SourcePortRangeSetCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   620  	return portset.PortRangeSet(c).Contains(requestInfo.SourceAddrPort.Port()), nil
   621  }
   622  
   623  // SourcePortSetCriterion restricts the source port to ports in a port set.
   624  type SourcePortSetCriterion portset.PortSet
   625  
   626  // Meet implements the Criterion Meet method.
   627  func (c *SourcePortSetCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   628  	return (*portset.PortSet)(c).Contains(requestInfo.SourceAddrPort.Port()), nil
   629  }
   630  
   631  // SourceIPCriterion restricts the source IP address.
   632  type SourceIPCriterion netipx.IPSet
   633  
   634  // Meet implements the Criterion Meet method.
   635  func (c *SourceIPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   636  	return (*netipx.IPSet)(c).Contains(requestInfo.SourceAddrPort.Addr().Unmap()), nil
   637  }
   638  
   639  // SourceGeoIPCountryCriterion restricts the source IP address by GeoIP country.
   640  type SourceGeoIPCountryCriterion struct {
   641  	countries []string
   642  	geoip     *geoip2.Reader
   643  	logger    *zap.Logger
   644  }
   645  
   646  // Meet implements the Criterion Meet method.
   647  func (c SourceGeoIPCountryCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   648  	return matchAddrToGeoIPCountries(c.countries, requestInfo.SourceAddrPort.Addr(), c.geoip, c.logger)
   649  }
   650  
   651  // DestPortCriterion restricts the destination port.
   652  type DestPortCriterion uint16
   653  
   654  // Meet implements the Criterion Meet method.
   655  func (c DestPortCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   656  	return uint16(c) == requestInfo.TargetAddr.Port(), nil
   657  }
   658  
   659  // DestPortRangeSetCriterion restricts the destination port to ports in a port range set.
   660  type DestPortRangeSetCriterion portset.PortRangeSet
   661  
   662  // Meet implements the Criterion Meet method.
   663  func (c DestPortRangeSetCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   664  	return portset.PortRangeSet(c).Contains(requestInfo.TargetAddr.Port()), nil
   665  }
   666  
   667  // DestPortSetCriterion restricts the destination port to ports in a port set.
   668  type DestPortSetCriterion portset.PortSet
   669  
   670  // Meet implements the Criterion Meet method.
   671  func (c *DestPortSetCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   672  	return (*portset.PortSet)(c).Contains(requestInfo.TargetAddr.Port()), nil
   673  }
   674  
   675  // DestDomainCriterion restricts the destination domain.
   676  type DestDomainCriterion []domainset.DomainSet
   677  
   678  // Meet implements the Criterion Meet method.
   679  func (c DestDomainCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   680  	if requestInfo.TargetAddr.IsIP() {
   681  		return false, nil
   682  	}
   683  	return matchDomainToDomainSets(c, requestInfo.TargetAddr.Domain()), nil
   684  }
   685  
   686  // DestDomainExpectedIPCriterion restricts the destination domain and its resolved IP address.
   687  type DestDomainExpectedIPCriterion struct {
   688  	destDomainCriterion DestDomainCriterion
   689  	expectedIPCriterion Criterion
   690  }
   691  
   692  // Meet implements the Criterion Meet method.
   693  func (c DestDomainExpectedIPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   694  	met, err := c.destDomainCriterion.Meet(network, requestInfo)
   695  	if !met {
   696  		return false, err
   697  	}
   698  	return c.expectedIPCriterion.Meet(network, requestInfo)
   699  }
   700  
   701  // DestIPCriterion restricts the destination IP address.
   702  type DestIPCriterion netipx.IPSet
   703  
   704  // Meet implements the Criterion Meet method.
   705  func (c *DestIPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   706  	if !requestInfo.TargetAddr.IsIP() {
   707  		return false, nil
   708  	}
   709  	return (*netipx.IPSet)(c).Contains(requestInfo.TargetAddr.IP().Unmap()), nil
   710  }
   711  
   712  // DestResolvedIPCriterion restricts the destination IP address or the destination domain's resolved IP address.
   713  type DestResolvedIPCriterion struct {
   714  	ipSet     *netipx.IPSet
   715  	resolvers []*dns.Resolver
   716  }
   717  
   718  // Meet implements the Criterion Meet method.
   719  func (c DestResolvedIPCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   720  	if requestInfo.TargetAddr.IsIP() {
   721  		return c.ipSet.Contains(requestInfo.TargetAddr.IP().Unmap()), nil
   722  	}
   723  	return matchDomainToIPSet(c.resolvers, requestInfo.TargetAddr.Domain(), c.ipSet)
   724  }
   725  
   726  // DestGeoIPCountryCriterion restricts the destination IP address by GeoIP country.
   727  type DestGeoIPCountryCriterion struct {
   728  	countries []string
   729  	geoip     *geoip2.Reader
   730  	logger    *zap.Logger
   731  }
   732  
   733  // Meet implements the Criterion Meet method.
   734  func (c DestGeoIPCountryCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   735  	if !requestInfo.TargetAddr.IsIP() {
   736  		return false, nil
   737  	}
   738  	return matchAddrToGeoIPCountries(c.countries, requestInfo.TargetAddr.IP(), c.geoip, c.logger)
   739  }
   740  
   741  // DestResolvedGeoIPCountryCriterion restricts the destination IP address or the destination domain's resolved IP address by GeoIP country.
   742  type DestResolvedGeoIPCountryCriterion struct {
   743  	countries []string
   744  	geoip     *geoip2.Reader
   745  	logger    *zap.Logger
   746  	resolvers []*dns.Resolver
   747  }
   748  
   749  // Meet implements the Criterion Meet method.
   750  func (c DestResolvedGeoIPCountryCriterion) Meet(network protocol, requestInfo RequestInfo) (bool, error) {
   751  	if requestInfo.TargetAddr.IsIP() {
   752  		return matchAddrToGeoIPCountries(c.countries, requestInfo.TargetAddr.IP(), c.geoip, c.logger)
   753  	}
   754  	return matchDomainToGeoIPCountries(c.resolvers, requestInfo.TargetAddr.Domain(), c.countries, c.geoip, c.logger)
   755  }
   756  
   757  func matchAddrToGeoIPCountries(countries []string, addr netip.Addr, geoip *geoip2.Reader, logger *zap.Logger) (bool, error) {
   758  	country, err := geoip.Country(addr.AsSlice())
   759  	if err != nil {
   760  		return false, err
   761  	}
   762  	if ce := logger.Check(zap.DebugLevel, "Matched GeoIP country"); ce != nil {
   763  		ce.Write(
   764  			zap.Stringer("ip", addr),
   765  			zap.String("country", country.Country.IsoCode),
   766  		)
   767  	}
   768  	return slices.Contains(countries, country.Country.IsoCode), nil
   769  }
   770  
   771  func matchResultToGeoIPCountries(countries []string, result dns.Result, geoip *geoip2.Reader, logger *zap.Logger) (bool, error) {
   772  	for _, v6 := range result.IPv6 {
   773  		return matchAddrToGeoIPCountries(countries, v6, geoip, logger)
   774  	}
   775  	for _, v4 := range result.IPv4 {
   776  		return matchAddrToGeoIPCountries(countries, v4, geoip, logger)
   777  	}
   778  	return false, nil
   779  }
   780  
   781  func matchResultToIPSet(ipSet *netipx.IPSet, result dns.Result) bool {
   782  	for _, v6 := range result.IPv6 {
   783  		return ipSet.Contains(v6)
   784  	}
   785  	for _, v4 := range result.IPv4 {
   786  		return ipSet.Contains(v4)
   787  	}
   788  	return false
   789  }
   790  
   791  func lookup(resolvers []*dns.Resolver, domain string) (result dns.Result, err error) {
   792  	for _, resolver := range resolvers {
   793  		result, err = resolver.Lookup(domain)
   794  		if err == dns.ErrLookup {
   795  			continue
   796  		}
   797  		return
   798  	}
   799  	return result, dns.ErrLookup
   800  }
   801  
   802  func matchDomainToDomainSets(domainSets []domainset.DomainSet, domain string) bool {
   803  	for _, ds := range domainSets {
   804  		if ds.Match(domain) {
   805  			return true
   806  		}
   807  	}
   808  	return false
   809  }
   810  
   811  func matchDomainToGeoIPCountries(resolvers []*dns.Resolver, domain string, countries []string, geoip *geoip2.Reader, logger *zap.Logger) (bool, error) {
   812  	result, err := lookup(resolvers, domain)
   813  	if err != nil {
   814  		return false, err
   815  	}
   816  	return matchResultToGeoIPCountries(countries, result, geoip, logger)
   817  }
   818  
   819  func matchDomainToIPSet(resolvers []*dns.Resolver, domain string, ipSet *netipx.IPSet) (bool, error) {
   820  	result, err := lookup(resolvers, domain)
   821  	if err != nil {
   822  		return false, err
   823  	}
   824  	return matchResultToIPSet(ipSet, result), nil
   825  }