github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/router/route.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/netip"
     8  	"slices"
     9  
    10  	"github.com/database64128/shadowsocks-go/bitset"
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  	"github.com/database64128/shadowsocks-go/dns"
    13  	"github.com/database64128/shadowsocks-go/domainset"
    14  	"github.com/database64128/shadowsocks-go/portset"
    15  	"github.com/database64128/shadowsocks-go/zerocopy"
    16  	"github.com/oschwald/geoip2-golang"
    17  	"go.uber.org/zap"
    18  	"go4.org/netipx"
    19  )
    20  
// ErrRejected is a special error that indicates the request is rejected.
// It is returned by [Route.TCPClient] and [Route.UDPClient] when the route
// has no client for the corresponding network.
var ErrRejected = errors.New("rejected")
    23  
var (
	// errNoAvailableResolvers is returned by lookup when the resolver list is
	// empty or every resolver reported dns.ErrLookup.
	errNoAvailableResolvers = errors.New("no available resolvers")

	// errPointlessPortCriteria is reported when a route's port criteria
	// cover all 65535 ports and therefore filter nothing.
	errPointlessPortCriteria = errors.New("matching all ports is equivalent to not having any port filtering rules")
)
    28  
// RouteConfig is a routing rule.
//
// Criteria of different kinds (network, source server, source user, source port,
// source address, destination port, destination address) must all be met for a
// request to match. Within the source address kind (prefixes and GeoIP countries)
// and within the destination address kind (domains, prefixes and GeoIP countries),
// the individual criteria are combined with OR logic.
type RouteConfig struct {
	// Name of this route. Used in logs to identify matched routes.
	// Must not be empty or "default".
	Name string `json:"name"`

	// Apply this route to "tcp" or "udp" only. If empty, match all requests.
	// Any other value is rejected as invalid.
	Network string `json:"network"`

	// Route matched requests to this client. Must not be empty.
	// The special value "reject" creates a route without any client,
	// so matched requests are rejected with ErrRejected.
	Client string `json:"client"`

	// When matching a domain target to IP prefixes, use this resolver to resolve the domain name.
	// If unspecified, use all resolvers by order.
	Resolver string `json:"resolver"`

	// Match requests from these servers. If empty, match all requests.
	// Every name must refer to a known server.
	FromServers []string `json:"fromServers"`

	// Match requests from these users. If empty, match all requests.
	FromUsers []string `json:"fromUsers"`

	// Match requests from these ports. If empty, match all requests.
	// Port 0 is not allowed.
	FromPorts []uint16 `json:"fromPorts"`

	// Match requests from these ports and port ranges. If empty, match all requests.
	// The parsed ports are merged with FromPorts into a single port set.
	FromPortRanges string `json:"fromPortRanges"`

	// Match requests from IP addresses in these prefixes. If empty, match all requests.
	FromPrefixes []netip.Prefix `json:"fromPrefixes"`

	// Match requests from IP addresses in these prefix sets. If empty, match all requests.
	// Every name must refer to a known prefix set.
	FromPrefixSets []string `json:"fromPrefixSets"`

	// Match requests from IP addresses in these countries. If empty, match all requests.
	// Requires a GeoIP database.
	FromGeoIPCountries []string `json:"fromGeoIPCountries"`

	// Match requests to these ports. If empty, match all requests.
	// Port 0 is not allowed.
	ToPorts []uint16 `json:"toPorts"`

	// Match requests to these ports and port ranges. If empty, match all requests.
	// The parsed ports are merged with ToPorts into a single port set.
	ToPortRanges string `json:"toPortRanges"`

	// Match requests to these domain targets. If empty, match all requests.
	ToDomains []string `json:"toDomains"`

	// Match requests to domains in these domain sets. If empty, match all requests.
	// Every name must refer to a known domain set.
	ToDomainSets []string `json:"toDomainSets"`

	// Require the matched domain target to resolve to IP addresses in these prefixes.
	// Only valid together with ToDomains or ToDomainSets.
	ToMatchedDomainExpectedPrefixes []netip.Prefix `json:"toMatchedDomainExpectedPrefixes"`

	// Require the matched domain target to resolve to IP addresses in these prefix sets.
	// Only valid together with ToDomains or ToDomainSets.
	ToMatchedDomainExpectedPrefixSets []string `json:"toMatchedDomainExpectedPrefixSets"`

	// Require the matched domain target to resolve to IP addresses in these countries.
	// Only valid together with ToDomains or ToDomainSets. Requires a GeoIP database.
	ToMatchedDomainExpectedGeoIPCountries []string `json:"toMatchedDomainExpectedGeoIPCountries"`

	// Match requests to IP addresses in these prefixes. If empty, match all requests.
	ToPrefixes []netip.Prefix `json:"toPrefixes"`

	// Match requests to IP addresses in these prefix sets. If empty, match all requests.
	// Every name must refer to a known prefix set.
	ToPrefixSets []string `json:"toPrefixSets"`

	// Match requests to IP addresses in these countries. If empty, match all requests.
	// Requires a GeoIP database.
	ToGeoIPCountries []string `json:"toGeoIPCountries"`

	// Do not resolve destination domains to match IP rules.
	// When set, ToPrefixes, ToPrefixSets and ToGeoIPCountries only match requests
	// whose target is an IP address.
	DisableNameResolutionForIPRules bool `json:"disableNameResolutionForIPRules"`

	// Invert source server matching logic. Match requests from all servers except those in FromServers.
	InvertFromServers bool `json:"invertFromServers"`

	// Invert source user matching logic. Match requests from all users except those in FromUsers.
	InvertFromUsers bool `json:"invertFromUsers"`

	// Invert source IP prefix matching logic. Match requests from all IP prefixes except those in FromPrefixes or FromPrefixSets.
	InvertFromPrefixes bool `json:"invertFromPrefixes"`

	// Invert source GeoIP country matching logic. Match requests from all countries except those in FromGeoIPCountries.
	InvertFromGeoIPCountries bool `json:"invertFromGeoIPCountries"`

	// Invert source port matching logic. Match requests from all ports except those in FromPorts or FromPortRanges.
	InvertFromPorts bool `json:"invertFromPorts"`

	// Invert destination domain matching logic. Match requests to all domains except those in ToDomains or ToDomainSets.
	InvertToDomains bool `json:"invertToDomains"`

	// Invert destination domain expected prefix matching logic. Match requests to all domains except those whose resolved IP addresses are in ToMatchedDomainExpectedPrefixes or ToMatchedDomainExpectedPrefixSets.
	InvertToMatchedDomainExpectedPrefixes bool `json:"invertToMatchedDomainExpectedPrefixes"`

	// Invert destination domain expected GeoIP country matching logic. Match requests to all domains except those whose resolved IP addresses are in ToMatchedDomainExpectedGeoIPCountries.
	InvertToMatchedDomainExpectedGeoIPCountries bool `json:"invertToMatchedDomainExpectedGeoIPCountries"`

	// Invert destination IP prefix matching logic. Match requests to all IP prefixes except those in ToPrefixes or ToPrefixSets.
	InvertToPrefixes bool `json:"invertToPrefixes"`

	// Invert destination GeoIP country matching logic. Match requests to all countries except those in ToGeoIPCountries.
	InvertToGeoIPCountries bool `json:"invertToGeoIPCountries"`

	// Invert destination port matching logic. Match requests to all ports except those in ToPorts or ToPortRanges.
	InvertToPorts bool `json:"invertToPorts"`
}
   131  
// Route creates a route from the RouteConfig.
//
// Parameters:
//   - geoip is the GeoIP country database reader. It may be nil only when the
//     config has no GeoIP criteria.
//   - logger is handed to the GeoIP criteria.
//   - resolvers is the ordered default resolver list. It is replaced by a
//     single resolver when rc.Resolver is set.
//   - resolverMap, tcpClientMap, udpClientMap, serverIndexByName, domainSetMap
//     and prefixSetMap resolve the names referenced by the config. A reference
//     to an unknown name is an error.
//
// The config is validated first, then criteria are appended in a fixed order:
// network, source servers, source users, source ports, source address group,
// destination ports, destination address group. [Route.Match] evaluates them
// in that order.
//
// On any error the zero Route is returned together with the error.
func (rc *RouteConfig) Route(geoip *geoip2.Reader, logger *zap.Logger, resolvers []dns.SimpleResolver, resolverMap map[string]dns.SimpleResolver, tcpClientMap map[string]zerocopy.TCPClient, udpClientMap map[string]zerocopy.UDPClient, serverIndexByName map[string]int, domainSetMap map[string]domainset.DomainSet, prefixSetMap map[string]*netipx.IPSet) (Route, error) {
	// Bad name.
	switch rc.Name {
	case "", "default":
		return Route{}, errors.New("route name cannot be empty or 'default'")
	}

	// Has GeoIP criteria but no GeoIP database.
	if geoip == nil && (len(rc.FromGeoIPCountries) > 0 || len(rc.ToGeoIPCountries) > 0 || len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0) {
		return Route{}, errors.New("missing GeoLite2 country database path")
	}

	// Needs to resolve domain names but has no resolvers.
	// Expected-IP criteria always need resolution; destination IP criteria need it
	// unless name resolution for IP rules is disabled.
	if len(resolvers) == 0 &&
		(len(rc.ToMatchedDomainExpectedPrefixes) > 0 ||
			len(rc.ToMatchedDomainExpectedPrefixSets) > 0 ||
			len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 ||
			!rc.DisableNameResolutionForIPRules &&
				(len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 || len(rc.ToGeoIPCountries) > 0)) {
		return Route{}, errors.New("missing resolvers for one or more criteria")
	}

	// Has resolved IP expectations but no destination domain criteria.
	if len(rc.ToDomains) == 0 && len(rc.ToDomainSets) == 0 &&
		(len(rc.ToMatchedDomainExpectedPrefixes) > 0 ||
			len(rc.ToMatchedDomainExpectedPrefixSets) > 0 ||
			len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0) {
		return Route{}, errors.New("missing destination domain criteria")
	}

	// A named resolver replaces the default resolver list for this route.
	if rc.Resolver != "" {
		resolver, ok := resolverMap[rc.Resolver]
		if !ok {
			return Route{}, fmt.Errorf("resolver not found: %s", rc.Resolver)
		}
		resolvers = []dns.SimpleResolver{resolver}
	}

	route := Route{name: rc.Name}

	// Network criterion. An empty network matches both TCP and UDP.
	switch rc.Network {
	case "":
	case "tcp":
		route.AddCriterion(NetworkTCPCriterion{}, false)
	case "udp":
		route.AddCriterion(NetworkUDPCriterion{}, false)
	default:
		return Route{}, fmt.Errorf("invalid network: %s", rc.Network)
	}

	// Look up the clients for the networks this route applies to.
	// For the "reject" client both clients stay nil, which makes
	// TCPClient and UDPClient return ErrRejected.
	if rc.Client != "reject" {
		switch rc.Network {
		case "", "tcp":
			route.tcpClient = tcpClientMap[rc.Client]
			if route.tcpClient == nil {
				return Route{}, fmt.Errorf("TCP client not found: %s", rc.Client)
			}
		}

		switch rc.Network {
		case "", "udp":
			route.udpClient = udpClientMap[rc.Client]
			if route.udpClient == nil {
				return Route{}, fmt.Errorf("UDP client not found: %s", rc.Client)
			}
		}
	}

	// Source server criterion: a bit set indexed by server index.
	if len(rc.FromServers) > 0 {
		sourceServerSet := bitset.NewBitSet(uint(len(serverIndexByName)))

		for _, server := range rc.FromServers {
			index, ok := serverIndexByName[server]
			if !ok {
				return Route{}, fmt.Errorf("server not found: %s", server)
			}
			sourceServerSet.Set(uint(index))
		}

		route.AddCriterion(SourceServerCriterion(sourceServerSet), rc.InvertFromServers)
	}

	// Source user criterion.
	if len(rc.FromUsers) > 0 {
		route.AddCriterion(SourceUserCriterion(rc.FromUsers), rc.InvertFromUsers)
	}

	// Source port criterion. FromPorts and FromPortRanges are merged into one port set.
	if len(rc.FromPorts) > 0 || rc.FromPortRanges != "" {
		var portSet portset.PortSet

		for _, port := range rc.FromPorts {
			if port == 0 {
				return Route{}, fmt.Errorf("bad fromPorts: %w", portset.ErrZeroPort)
			}
			portSet.Add(port)
		}

		if err := portSet.Parse(rc.FromPortRanges); err != nil {
			return Route{}, fmt.Errorf("failed to parse source port ranges: %w", err)
		}

		// Pick the criterion representation based on how many ports were collected.
		portCount := portSet.Count()
		switch portCount {
		case 0:
			// NOTE(review): relies on Parse either adding a port or failing for a
			// non-empty string; confirm against the portset package.
			panic("unreachable")
		case 1:
			route.AddCriterion(SourcePortCriterion(portSet.First()), rc.InvertFromPorts)
		case 65535:
			// Every port from 1 to 65535 is in the set.
			return Route{}, fmt.Errorf("bad source port criteria: %w", errPointlessPortCriteria)
		default:
			// Up to 16 ranges are matched as a range set; more than that falls
			// back to the full port set.
			portRangeCount := portSet.RangeCount()
			if portRangeCount <= 16 {
				route.AddCriterion(SourcePortRangeSetCriterion(portSet.RangeSet()), rc.InvertFromPorts)
			} else {
				// Meet is defined on *SourcePortSetCriterion, so a pointer is added.
				sourcePortSetCriterion := SourcePortSetCriterion(portSet)
				route.AddCriterion(&sourcePortSetCriterion, rc.InvertFromPorts)
			}
		}
	}

	// Source address group: prefixes/prefix sets OR GeoIP countries.
	// Inversion is applied to each member before they are OR'ed together.
	if len(rc.FromPrefixes) > 0 || len(rc.FromPrefixSets) > 0 || len(rc.FromGeoIPCountries) > 0 {
		var group CriterionGroupOR

		if len(rc.FromPrefixes) > 0 || len(rc.FromPrefixSets) > 0 {
			var sb netipx.IPSetBuilder

			for _, prefix := range rc.FromPrefixes {
				sb.AddPrefix(prefix)
			}

			for _, prefixSet := range rc.FromPrefixSets {
				s, ok := prefixSetMap[prefixSet]
				if !ok {
					return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
				}
				sb.AddSet(s)
			}

			sourceIPSet, err := sb.IPSet()
			if err != nil {
				return Route{}, fmt.Errorf("failed to build sourceIPSet: %w", err)
			}

			group.AddCriterion((*SourceIPCriterion)(sourceIPSet), rc.InvertFromPrefixes)
		}

		if len(rc.FromGeoIPCountries) > 0 {
			group.AddCriterion(SourceGeoIPCountryCriterion{
				countries: rc.FromGeoIPCountries,
				geoip:     geoip,
				logger:    logger,
			}, rc.InvertFromGeoIPCountries)
		}

		route.criteria = group.AppendTo(route.criteria)
	}

	// Destination port criterion. Mirrors the source port handling above.
	if len(rc.ToPorts) > 0 || rc.ToPortRanges != "" {
		var portSet portset.PortSet

		for _, port := range rc.ToPorts {
			if port == 0 {
				return Route{}, fmt.Errorf("bad toPorts: %w", portset.ErrZeroPort)
			}
			portSet.Add(port)
		}

		if err := portSet.Parse(rc.ToPortRanges); err != nil {
			return Route{}, fmt.Errorf("failed to parse destination port ranges: %w", err)
		}

		portCount := portSet.Count()
		switch portCount {
		case 0:
			// NOTE(review): same assumption about Parse as for source ports.
			panic("unreachable")
		case 1:
			route.AddCriterion(DestPortCriterion(portSet.First()), rc.InvertToPorts)
		case 65535:
			// Every port from 1 to 65535 is in the set.
			return Route{}, fmt.Errorf("bad destination port criteria: %w", errPointlessPortCriteria)
		default:
			portRangeCount := portSet.RangeCount()
			if portRangeCount <= 16 {
				route.AddCriterion(DestPortRangeSetCriterion(portSet.RangeSet()), rc.InvertToPorts)
			} else {
				// Meet is defined on *DestPortSetCriterion, so a pointer is added.
				destPortSetCriterion := DestPortSetCriterion(portSet)
				route.AddCriterion(&destPortSetCriterion, rc.InvertToPorts)
			}
		}
	}

	// Destination address group: domains OR prefixes/prefix sets OR GeoIP countries.
	if len(rc.ToDomains) > 0 || len(rc.ToDomainSets) > 0 || len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 || len(rc.ToGeoIPCountries) > 0 {
		var group CriterionGroupOR

		if len(rc.ToDomains) > 0 || len(rc.ToDomainSets) > 0 {
			// Inline ToDomains, when present, become one extra domain set at index 0.
			var defaultDomainSetCount int

			if len(rc.ToDomains) > 0 {
				defaultDomainSetCount = 1
			}

			domainSets := make([]domainset.DomainSet, defaultDomainSetCount+len(rc.ToDomainSets))

			if defaultDomainSetCount == 1 {
				mb := domainset.DomainLinearMatcher(rc.ToDomains)
				ds, err := mb.AppendTo(nil)
				if err != nil {
					return Route{}, err
				}
				domainSets[0] = ds
			}

			for i, tds := range rc.ToDomainSets {
				ds, ok := domainSetMap[tds]
				if !ok {
					return Route{}, fmt.Errorf("domain set not found: %s", tds)
				}
				domainSets[defaultDomainSetCount+i] = ds
			}

			// With resolved IP expectations, the domain criterion is paired with an
			// OR group over the expected prefixes and expected GeoIP countries.
			if len(rc.ToMatchedDomainExpectedPrefixes) > 0 || len(rc.ToMatchedDomainExpectedPrefixSets) > 0 || len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 {
				var expectedIPCriterionGroup CriterionGroupOR

				if len(rc.ToMatchedDomainExpectedPrefixes) > 0 || len(rc.ToMatchedDomainExpectedPrefixSets) > 0 {
					var sb netipx.IPSetBuilder

					for _, prefix := range rc.ToMatchedDomainExpectedPrefixes {
						sb.AddPrefix(prefix)
					}

					for _, prefixSet := range rc.ToMatchedDomainExpectedPrefixSets {
						s, ok := prefixSetMap[prefixSet]
						if !ok {
							return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
						}
						sb.AddSet(s)
					}

					expectedIPSet, err := sb.IPSet()
					if err != nil {
						return Route{}, fmt.Errorf("failed to build expectedIPSet: %w", err)
					}

					expectedIPCriterionGroup.AddCriterion(DestResolvedIPCriterion{expectedIPSet, resolvers}, rc.InvertToMatchedDomainExpectedPrefixes)
				}

				if len(rc.ToMatchedDomainExpectedGeoIPCountries) > 0 {
					expectedIPCriterionGroup.AddCriterion(DestResolvedGeoIPCountryCriterion{
						countries: rc.ToMatchedDomainExpectedGeoIPCountries,
						geoip:     geoip,
						logger:    logger,
						resolvers: resolvers,
					}, rc.InvertToMatchedDomainExpectedGeoIPCountries)
				}

				// The enclosing condition guarantees the group is non-empty,
				// so Criterion() does not return nil here.
				group.AddCriterion(DestDomainExpectedIPCriterion{domainSets, expectedIPCriterionGroup.Criterion()}, rc.InvertToDomains)
			} else {
				group.AddCriterion(DestDomainCriterion(domainSets), rc.InvertToDomains)
			}
		}

		if len(rc.ToPrefixes) > 0 || len(rc.ToPrefixSets) > 0 {
			var sb netipx.IPSetBuilder

			for _, prefix := range rc.ToPrefixes {
				sb.AddPrefix(prefix)
			}

			for _, prefixSet := range rc.ToPrefixSets {
				s, ok := prefixSetMap[prefixSet]
				if !ok {
					return Route{}, fmt.Errorf("prefix set not found: %s", prefixSet)
				}
				sb.AddSet(s)
			}

			destIPSet, err := sb.IPSet()
			if err != nil {
				return Route{}, fmt.Errorf("failed to build destIPSet: %w", err)
			}

			// Without name resolution only IP targets can match; otherwise
			// domain targets are resolved and the resolved address is matched.
			if rc.DisableNameResolutionForIPRules {
				group.AddCriterion((*DestIPCriterion)(destIPSet), rc.InvertToPrefixes)
			} else {
				group.AddCriterion(DestResolvedIPCriterion{destIPSet, resolvers}, rc.InvertToPrefixes)
			}
		}

		if len(rc.ToGeoIPCountries) > 0 {
			if rc.DisableNameResolutionForIPRules {
				group.AddCriterion(DestGeoIPCountryCriterion{
					countries: rc.ToGeoIPCountries,
					geoip:     geoip,
					logger:    logger,
				}, rc.InvertToGeoIPCountries)
			} else {
				group.AddCriterion(DestResolvedGeoIPCountryCriterion{
					countries: rc.ToGeoIPCountries,
					geoip:     geoip,
					logger:    logger,
					resolvers: resolvers,
				}, rc.InvertToGeoIPCountries)
			}
		}

		route.criteria = group.AppendTo(route.criteria)
	}

	return route, nil
}
   441  
// Route controls which client a request is routed to.
type Route struct {
	// name identifies the route and is what String returns.
	name string

	// criteria must all be met, in order, for a request to match.
	// A route without criteria matches every request.
	criteria []Criterion

	// tcpClient handles matched TCP requests. A nil client means rejection.
	tcpClient zerocopy.TCPClient

	// udpClient handles matched UDP requests. A nil client means rejection.
	udpClient zerocopy.UDPClient
}
   449  
// String returns the name of the route.
// The method has a pointer receiver, so only *Route carries it.
func (r *Route) String() string {
	return r.name
}
   454  
   455  // AddCriterion adds a criterion to the route.
   456  func (r *Route) AddCriterion(criterion Criterion, invert bool) {
   457  	if invert {
   458  		criterion = InvertedCriterion{Inner: criterion}
   459  	}
   460  	r.criteria = append(r.criteria, criterion)
   461  }
   462  
   463  // Match returns whether the request matches the route.
   464  func (r *Route) Match(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   465  	for _, criterion := range r.criteria {
   466  		met, err := criterion.Meet(ctx, network, requestInfo)
   467  		if !met {
   468  			return false, err
   469  		}
   470  	}
   471  	return true, nil
   472  }
   473  
   474  // TCPClient returns the TCP client to use for the request.
   475  func (r *Route) TCPClient() (zerocopy.TCPClient, error) {
   476  	if r.tcpClient == nil {
   477  		return nil, ErrRejected
   478  	}
   479  	return r.tcpClient, nil
   480  }
   481  
   482  // UDPClient returns the UDP client to use for the request.
   483  func (r *Route) UDPClient() (zerocopy.UDPClient, error) {
   484  	if r.udpClient == nil {
   485  		return nil, ErrRejected
   486  	}
   487  	return r.udpClient, nil
   488  }
   489  
// Criterion is used by [Route] to determine whether a request matches the route.
type Criterion interface {
	// Meet returns whether the request meets the criterion.
	//
	// network is the protocol of the request and requestInfo describes its
	// source and target. A non-nil error means the criterion could not be
	// evaluated; [Route.Match] passes it on to its caller.
	Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error)
}
   495  
   496  // InvertedCriterion is like the inner criterion, but inverted.
   497  type InvertedCriterion struct {
   498  	Inner Criterion
   499  }
   500  
   501  // Meet implements the Criterion Meet method.
   502  func (c InvertedCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   503  	met, err := c.Inner.Meet(ctx, network, requestInfo)
   504  	if err != nil {
   505  		return false, err
   506  	}
   507  	return !met, nil
   508  }
   509  
   510  // CriterionGroupOR groups multiple criteria together with OR logic.
   511  type CriterionGroupOR struct {
   512  	Criteria []Criterion
   513  }
   514  
   515  // AddCriterion adds a criterion to the group.
   516  func (g *CriterionGroupOR) AddCriterion(criterion Criterion, invert bool) {
   517  	if invert {
   518  		criterion = InvertedCriterion{Inner: criterion}
   519  	}
   520  	g.Criteria = append(g.Criteria, criterion)
   521  }
   522  
   523  // Meet returns whether the request meets any of the criteria.
   524  func (g CriterionGroupOR) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   525  	for _, criterion := range g.Criteria {
   526  		met, err := criterion.Meet(ctx, network, requestInfo)
   527  		if err != nil {
   528  			return false, err
   529  		}
   530  		if met {
   531  			return true, nil
   532  		}
   533  	}
   534  	return false, nil
   535  }
   536  
   537  // Criterion returns a single criterion that represents the group, or nil if the group is empty.
   538  func (g CriterionGroupOR) Criterion() Criterion {
   539  	switch len(g.Criteria) {
   540  	case 0:
   541  		return nil
   542  	case 1:
   543  		return g.Criteria[0]
   544  	default:
   545  		return g
   546  	}
   547  }
   548  
   549  // AppendTo appends the group to the criterion slice.
   550  // When there are more than one criterion in the group, the group itself is appended.
   551  // When there is only one criterion in the group, the criterion is appended directly.
   552  // When there are no criteria in the group, the criterion slice is returned unchanged.
   553  func (g CriterionGroupOR) AppendTo(criteria []Criterion) []Criterion {
   554  	switch len(g.Criteria) {
   555  	case 0:
   556  		return criteria
   557  	case 1:
   558  		return append(criteria, g.Criteria[0])
   559  	default:
   560  		return append(criteria, g)
   561  	}
   562  }
   563  
// protocol identifies the network of a request when it is matched against criteria.
type protocol byte

const (
	// protocolTCP marks a TCP request.
	protocolTCP protocol = iota
	// protocolUDP marks a UDP request.
	protocolUDP
)
   570  
// RequestInfo contains information about a request that can be met by one or more criteria.
type RequestInfo struct {
	// ServerIndex is the index of the server that received the request.
	// SourceServerCriterion uses it as a bit index.
	// NOTE(review): presumably the same indexing as serverIndexByName passed to
	// RouteConfig.Route — confirm against the caller.
	ServerIndex int

	// Username is compared against the user list of SourceUserCriterion.
	Username string

	// SourceAddrPort is the source address and port, used by the source port,
	// source IP and source GeoIP criteria.
	SourceAddrPort netip.AddrPort

	// TargetAddr is the request target, either an IP address or a domain name,
	// together with the target port.
	TargetAddr conn.Addr
}
   578  
   579  // NetworkTCPCriterion restricts the network to TCP.
   580  type NetworkTCPCriterion struct{}
   581  
   582  // Meet implements the Criterion Meet method.
   583  func (NetworkTCPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   584  	return network == protocolTCP, nil
   585  }
   586  
   587  // NetworkUDPCriterion restricts the network to UDP.
   588  type NetworkUDPCriterion struct{}
   589  
   590  // Meet implements the Criterion Meet method.
   591  func (NetworkUDPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   592  	return network == protocolUDP, nil
   593  }
   594  
   595  // SourceServerCriterion restricts the source server.
   596  type SourceServerCriterion bitset.BitSet
   597  
   598  // Meet implements the Criterion Meet method.
   599  func (c SourceServerCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   600  	return bitset.BitSet(c).IsSet(uint(requestInfo.ServerIndex)), nil
   601  }
   602  
   603  // SourceUserCriterion restricts the source user.
   604  type SourceUserCriterion []string
   605  
   606  // Meet implements the Criterion Meet method.
   607  func (c SourceUserCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   608  	return slices.Contains(c, requestInfo.Username), nil
   609  }
   610  
   611  // SourcePortCriterion restricts the source port.
   612  type SourcePortCriterion uint16
   613  
   614  // Meet implements the Criterion Meet method.
   615  func (c SourcePortCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   616  	return uint16(c) == requestInfo.SourceAddrPort.Port(), nil
   617  }
   618  
   619  // SourcePortRangeSetCriterion restricts the source port to ports in a port range set.
   620  type SourcePortRangeSetCriterion portset.PortRangeSet
   621  
   622  // Meet implements the Criterion Meet method.
   623  func (c SourcePortRangeSetCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   624  	return portset.PortRangeSet(c).Contains(requestInfo.SourceAddrPort.Port()), nil
   625  }
   626  
   627  // SourcePortSetCriterion restricts the source port to ports in a port set.
   628  type SourcePortSetCriterion portset.PortSet
   629  
   630  // Meet implements the Criterion Meet method.
   631  func (c *SourcePortSetCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   632  	return (*portset.PortSet)(c).Contains(requestInfo.SourceAddrPort.Port()), nil
   633  }
   634  
   635  // SourceIPCriterion restricts the source IP address.
   636  type SourceIPCriterion netipx.IPSet
   637  
   638  // Meet implements the Criterion Meet method.
   639  func (c *SourceIPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   640  	return (*netipx.IPSet)(c).Contains(requestInfo.SourceAddrPort.Addr().Unmap()), nil
   641  }
   642  
   643  // SourceGeoIPCountryCriterion restricts the source IP address by GeoIP country.
   644  type SourceGeoIPCountryCriterion struct {
   645  	countries []string
   646  	geoip     *geoip2.Reader
   647  	logger    *zap.Logger
   648  }
   649  
   650  // Meet implements the Criterion Meet method.
   651  func (c SourceGeoIPCountryCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   652  	return matchAddrToGeoIPCountries(c.countries, requestInfo.SourceAddrPort.Addr(), c.geoip, c.logger)
   653  }
   654  
   655  // DestPortCriterion restricts the destination port.
   656  type DestPortCriterion uint16
   657  
   658  // Meet implements the Criterion Meet method.
   659  func (c DestPortCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   660  	return uint16(c) == requestInfo.TargetAddr.Port(), nil
   661  }
   662  
   663  // DestPortRangeSetCriterion restricts the destination port to ports in a port range set.
   664  type DestPortRangeSetCriterion portset.PortRangeSet
   665  
   666  // Meet implements the Criterion Meet method.
   667  func (c DestPortRangeSetCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   668  	return portset.PortRangeSet(c).Contains(requestInfo.TargetAddr.Port()), nil
   669  }
   670  
   671  // DestPortSetCriterion restricts the destination port to ports in a port set.
   672  type DestPortSetCriterion portset.PortSet
   673  
   674  // Meet implements the Criterion Meet method.
   675  func (c *DestPortSetCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   676  	return (*portset.PortSet)(c).Contains(requestInfo.TargetAddr.Port()), nil
   677  }
   678  
   679  // DestDomainCriterion restricts the destination domain.
   680  type DestDomainCriterion []domainset.DomainSet
   681  
   682  // Meet implements the Criterion Meet method.
   683  func (c DestDomainCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   684  	if requestInfo.TargetAddr.IsIP() {
   685  		return false, nil
   686  	}
   687  	return matchDomainToDomainSets(c, requestInfo.TargetAddr.Domain()), nil
   688  }
   689  
   690  // DestDomainExpectedIPCriterion restricts the destination domain and its resolved IP address.
   691  type DestDomainExpectedIPCriterion struct {
   692  	destDomainCriterion DestDomainCriterion
   693  	expectedIPCriterion Criterion
   694  }
   695  
   696  // Meet implements the Criterion Meet method.
   697  func (c DestDomainExpectedIPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   698  	met, err := c.destDomainCriterion.Meet(ctx, network, requestInfo)
   699  	if !met {
   700  		return false, err
   701  	}
   702  	return c.expectedIPCriterion.Meet(ctx, network, requestInfo)
   703  }
   704  
   705  // DestIPCriterion restricts the destination IP address.
   706  type DestIPCriterion netipx.IPSet
   707  
   708  // Meet implements the Criterion Meet method.
   709  func (c *DestIPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   710  	if !requestInfo.TargetAddr.IsIP() {
   711  		return false, nil
   712  	}
   713  	return (*netipx.IPSet)(c).Contains(requestInfo.TargetAddr.IP().Unmap()), nil
   714  }
   715  
   716  // DestResolvedIPCriterion restricts the destination IP address or the destination domain's resolved IP address.
   717  type DestResolvedIPCriterion struct {
   718  	ipSet     *netipx.IPSet
   719  	resolvers []dns.SimpleResolver
   720  }
   721  
   722  // Meet implements the Criterion Meet method.
   723  func (c DestResolvedIPCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   724  	if requestInfo.TargetAddr.IsIP() {
   725  		return c.ipSet.Contains(requestInfo.TargetAddr.IP().Unmap()), nil
   726  	}
   727  	return matchDomainToIPSet(ctx, c.resolvers, requestInfo.TargetAddr.Domain(), c.ipSet)
   728  }
   729  
   730  // DestGeoIPCountryCriterion restricts the destination IP address by GeoIP country.
   731  type DestGeoIPCountryCriterion struct {
   732  	countries []string
   733  	geoip     *geoip2.Reader
   734  	logger    *zap.Logger
   735  }
   736  
   737  // Meet implements the Criterion Meet method.
   738  func (c DestGeoIPCountryCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   739  	if !requestInfo.TargetAddr.IsIP() {
   740  		return false, nil
   741  	}
   742  	return matchAddrToGeoIPCountries(c.countries, requestInfo.TargetAddr.IP(), c.geoip, c.logger)
   743  }
   744  
   745  // DestResolvedGeoIPCountryCriterion restricts the destination IP address or the destination domain's resolved IP address by GeoIP country.
   746  type DestResolvedGeoIPCountryCriterion struct {
   747  	countries []string
   748  	geoip     *geoip2.Reader
   749  	logger    *zap.Logger
   750  	resolvers []dns.SimpleResolver
   751  }
   752  
   753  // Meet implements the Criterion Meet method.
   754  func (c DestResolvedGeoIPCountryCriterion) Meet(ctx context.Context, network protocol, requestInfo RequestInfo) (bool, error) {
   755  	if requestInfo.TargetAddr.IsIP() {
   756  		return matchAddrToGeoIPCountries(c.countries, requestInfo.TargetAddr.IP(), c.geoip, c.logger)
   757  	}
   758  	return matchDomainToGeoIPCountries(ctx, c.resolvers, requestInfo.TargetAddr.Domain(), c.countries, c.geoip, c.logger)
   759  }
   760  
   761  func matchAddrToGeoIPCountries(countries []string, addr netip.Addr, geoip *geoip2.Reader, logger *zap.Logger) (bool, error) {
   762  	country, err := geoip.Country(addr.AsSlice())
   763  	if err != nil {
   764  		return false, err
   765  	}
   766  	if ce := logger.Check(zap.DebugLevel, "Matched GeoIP country"); ce != nil {
   767  		ce.Write(
   768  			zap.Stringer("ip", addr),
   769  			zap.String("country", country.Country.IsoCode),
   770  		)
   771  	}
   772  	return slices.Contains(countries, country.Country.IsoCode), nil
   773  }
   774  
   775  func lookup(ctx context.Context, resolvers []dns.SimpleResolver, domain string) (ip netip.Addr, err error) {
   776  	for _, resolver := range resolvers {
   777  		ip, err = resolver.LookupIP(ctx, domain)
   778  		if err == dns.ErrLookup {
   779  			continue
   780  		}
   781  		return
   782  	}
   783  	return ip, errNoAvailableResolvers
   784  }
   785  
   786  func matchDomainToDomainSets(domainSets []domainset.DomainSet, domain string) bool {
   787  	for _, ds := range domainSets {
   788  		if ds.Match(domain) {
   789  			return true
   790  		}
   791  	}
   792  	return false
   793  }
   794  
   795  func matchDomainToGeoIPCountries(ctx context.Context, resolvers []dns.SimpleResolver, domain string, countries []string, geoip *geoip2.Reader, logger *zap.Logger) (bool, error) {
   796  	ip, err := lookup(ctx, resolvers, domain)
   797  	if err != nil {
   798  		return false, err
   799  	}
   800  	return matchAddrToGeoIPCountries(countries, ip, geoip, logger)
   801  }
   802  
   803  func matchDomainToIPSet(ctx context.Context, resolvers []dns.SimpleResolver, domain string, ipSet *netipx.IPSet) (bool, error) {
   804  	ip, err := lookup(ctx, resolvers, domain)
   805  	if err != nil {
   806  		return false, err
   807  	}
   808  	return ipSet.Contains(ip.Unmap()), nil
   809  }