github.com/sagernet/sing-box@v1.9.0-rc.20/route/rule_dns.go (about)

     1  package route
     2  
     3  import (
     4  	"net/netip"
     5  
     6  	"github.com/sagernet/sing-box/adapter"
     7  	C "github.com/sagernet/sing-box/constant"
     8  	"github.com/sagernet/sing-box/log"
     9  	"github.com/sagernet/sing-box/option"
    10  	"github.com/sagernet/sing/common"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  )
    13  
    14  func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) {
    15  	switch options.Type {
    16  	case "", C.RuleTypeDefault:
    17  		if !options.DefaultOptions.IsValid() {
    18  			return nil, E.New("missing conditions")
    19  		}
    20  		if options.DefaultOptions.Server == "" && checkServer {
    21  			return nil, E.New("missing server field")
    22  		}
    23  		return NewDefaultDNSRule(router, logger, options.DefaultOptions)
    24  	case C.RuleTypeLogical:
    25  		if !options.LogicalOptions.IsValid() {
    26  			return nil, E.New("missing conditions")
    27  		}
    28  		if options.LogicalOptions.Server == "" && checkServer {
    29  			return nil, E.New("missing server field")
    30  		}
    31  		return NewLogicalDNSRule(router, logger, options.LogicalOptions)
    32  	default:
    33  		return nil, E.New("unknown rule type: ", options.Type)
    34  	}
    35  }
    36  
    37  var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
    38  
    39  type DefaultDNSRule struct {
    40  	abstractDefaultRule
    41  	disableCache bool
    42  	rewriteTTL   *uint32
    43  	clientSubnet *netip.Prefix
    44  }
    45  
    46  func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
    47  	rule := &DefaultDNSRule{
    48  		abstractDefaultRule: abstractDefaultRule{
    49  			invert:   options.Invert,
    50  			outbound: options.Server,
    51  		},
    52  		disableCache: options.DisableCache,
    53  		rewriteTTL:   options.RewriteTTL,
    54  		clientSubnet: (*netip.Prefix)(options.ClientSubnet),
    55  	}
    56  	if len(options.Inbound) > 0 {
    57  		item := NewInboundRule(options.Inbound)
    58  		rule.items = append(rule.items, item)
    59  		rule.allItems = append(rule.allItems, item)
    60  	}
    61  	if options.IPVersion > 0 {
    62  		switch options.IPVersion {
    63  		case 4, 6:
    64  			item := NewIPVersionItem(options.IPVersion == 6)
    65  			rule.items = append(rule.items, item)
    66  			rule.allItems = append(rule.allItems, item)
    67  		default:
    68  			return nil, E.New("invalid ip version: ", options.IPVersion)
    69  		}
    70  	}
    71  	if len(options.QueryType) > 0 {
    72  		item := NewQueryTypeItem(options.QueryType)
    73  		rule.items = append(rule.items, item)
    74  		rule.allItems = append(rule.allItems, item)
    75  	}
    76  	if len(options.Network) > 0 {
    77  		item := NewNetworkItem(options.Network)
    78  		rule.items = append(rule.items, item)
    79  		rule.allItems = append(rule.allItems, item)
    80  	}
    81  	if len(options.AuthUser) > 0 {
    82  		item := NewAuthUserItem(options.AuthUser)
    83  		rule.items = append(rule.items, item)
    84  		rule.allItems = append(rule.allItems, item)
    85  	}
    86  	if len(options.Protocol) > 0 {
    87  		item := NewProtocolItem(options.Protocol)
    88  		rule.items = append(rule.items, item)
    89  		rule.allItems = append(rule.allItems, item)
    90  	}
    91  	if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 {
    92  		item := NewDomainItem(options.Domain, options.DomainSuffix)
    93  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
    94  		rule.allItems = append(rule.allItems, item)
    95  	}
    96  	if len(options.DomainKeyword) > 0 {
    97  		item := NewDomainKeywordItem(options.DomainKeyword)
    98  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
    99  		rule.allItems = append(rule.allItems, item)
   100  	}
   101  	if len(options.DomainRegex) > 0 {
   102  		item, err := NewDomainRegexItem(options.DomainRegex)
   103  		if err != nil {
   104  			return nil, E.Cause(err, "domain_regex")
   105  		}
   106  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   107  		rule.allItems = append(rule.allItems, item)
   108  	}
   109  	if len(options.Geosite) > 0 {
   110  		item := NewGeositeItem(router, logger, options.Geosite)
   111  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   112  		rule.allItems = append(rule.allItems, item)
   113  	}
   114  	if len(options.SourceGeoIP) > 0 {
   115  		item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
   116  		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
   117  		rule.allItems = append(rule.allItems, item)
   118  	}
   119  	if len(options.GeoIP) > 0 {
   120  		item := NewGeoIPItem(router, logger, false, options.GeoIP)
   121  		rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
   122  		rule.allItems = append(rule.allItems, item)
   123  	}
   124  	if len(options.SourceIPCIDR) > 0 {
   125  		item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
   126  		if err != nil {
   127  			return nil, E.Cause(err, "source_ip_cidr")
   128  		}
   129  		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
   130  		rule.allItems = append(rule.allItems, item)
   131  	}
   132  	if len(options.IPCIDR) > 0 {
   133  		item, err := NewIPCIDRItem(false, options.IPCIDR)
   134  		if err != nil {
   135  			return nil, E.Cause(err, "ip_cidr")
   136  		}
   137  		rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
   138  		rule.allItems = append(rule.allItems, item)
   139  	}
   140  	if options.SourceIPIsPrivate {
   141  		item := NewIPIsPrivateItem(true)
   142  		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
   143  		rule.allItems = append(rule.allItems, item)
   144  	}
   145  	if options.IPIsPrivate {
   146  		item := NewIPIsPrivateItem(false)
   147  		rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
   148  		rule.allItems = append(rule.allItems, item)
   149  	}
   150  	if len(options.SourcePort) > 0 {
   151  		item := NewPortItem(true, options.SourcePort)
   152  		rule.sourcePortItems = append(rule.sourcePortItems, item)
   153  		rule.allItems = append(rule.allItems, item)
   154  	}
   155  	if len(options.SourcePortRange) > 0 {
   156  		item, err := NewPortRangeItem(true, options.SourcePortRange)
   157  		if err != nil {
   158  			return nil, E.Cause(err, "source_port_range")
   159  		}
   160  		rule.sourcePortItems = append(rule.sourcePortItems, item)
   161  		rule.allItems = append(rule.allItems, item)
   162  	}
   163  	if len(options.Port) > 0 {
   164  		item := NewPortItem(false, options.Port)
   165  		rule.destinationPortItems = append(rule.destinationPortItems, item)
   166  		rule.allItems = append(rule.allItems, item)
   167  	}
   168  	if len(options.PortRange) > 0 {
   169  		item, err := NewPortRangeItem(false, options.PortRange)
   170  		if err != nil {
   171  			return nil, E.Cause(err, "port_range")
   172  		}
   173  		rule.destinationPortItems = append(rule.destinationPortItems, item)
   174  		rule.allItems = append(rule.allItems, item)
   175  	}
   176  	if len(options.ProcessName) > 0 {
   177  		item := NewProcessItem(options.ProcessName)
   178  		rule.items = append(rule.items, item)
   179  		rule.allItems = append(rule.allItems, item)
   180  	}
   181  	if len(options.ProcessPath) > 0 {
   182  		item := NewProcessPathItem(options.ProcessPath)
   183  		rule.items = append(rule.items, item)
   184  		rule.allItems = append(rule.allItems, item)
   185  	}
   186  	if len(options.PackageName) > 0 {
   187  		item := NewPackageNameItem(options.PackageName)
   188  		rule.items = append(rule.items, item)
   189  		rule.allItems = append(rule.allItems, item)
   190  	}
   191  	if len(options.User) > 0 {
   192  		item := NewUserItem(options.User)
   193  		rule.items = append(rule.items, item)
   194  		rule.allItems = append(rule.allItems, item)
   195  	}
   196  	if len(options.UserID) > 0 {
   197  		item := NewUserIDItem(options.UserID)
   198  		rule.items = append(rule.items, item)
   199  		rule.allItems = append(rule.allItems, item)
   200  	}
   201  	if len(options.Outbound) > 0 {
   202  		item := NewOutboundRule(options.Outbound)
   203  		rule.items = append(rule.items, item)
   204  		rule.allItems = append(rule.allItems, item)
   205  	}
   206  	if options.ClashMode != "" {
   207  		item := NewClashModeItem(router, options.ClashMode)
   208  		rule.items = append(rule.items, item)
   209  		rule.allItems = append(rule.allItems, item)
   210  	}
   211  	if len(options.WIFISSID) > 0 {
   212  		item := NewWIFISSIDItem(router, options.WIFISSID)
   213  		rule.items = append(rule.items, item)
   214  		rule.allItems = append(rule.allItems, item)
   215  	}
   216  	if len(options.WIFIBSSID) > 0 {
   217  		item := NewWIFIBSSIDItem(router, options.WIFIBSSID)
   218  		rule.items = append(rule.items, item)
   219  		rule.allItems = append(rule.allItems, item)
   220  	}
   221  	if len(options.RuleSet) > 0 {
   222  		item := NewRuleSetItem(router, options.RuleSet, options.RuleSetIPCIDRMatchSource)
   223  		rule.items = append(rule.items, item)
   224  		rule.allItems = append(rule.allItems, item)
   225  	}
   226  	return rule, nil
   227  }
   228  
   229  func (r *DefaultDNSRule) DisableCache() bool {
   230  	return r.disableCache
   231  }
   232  
   233  func (r *DefaultDNSRule) RewriteTTL() *uint32 {
   234  	return r.rewriteTTL
   235  }
   236  
   237  func (r *DefaultDNSRule) ClientSubnet() *netip.Prefix {
   238  	return r.clientSubnet
   239  }
   240  
   241  func (r *DefaultDNSRule) WithAddressLimit() bool {
   242  	if len(r.destinationIPCIDRItems) > 0 {
   243  		return true
   244  	}
   245  	for _, rawRule := range r.items {
   246  		ruleSet, isRuleSet := rawRule.(*RuleSetItem)
   247  		if !isRuleSet {
   248  			continue
   249  		}
   250  		if ruleSet.ContainsDestinationIPCIDRRule() {
   251  			return true
   252  		}
   253  	}
   254  	return false
   255  }
   256  
   257  func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
   258  	metadata.IgnoreDestinationIPCIDRMatch = true
   259  	defer func() {
   260  		metadata.IgnoreDestinationIPCIDRMatch = false
   261  	}()
   262  	return r.abstractDefaultRule.Match(metadata)
   263  }
   264  
   265  func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
   266  	return r.abstractDefaultRule.Match(metadata)
   267  }
   268  
   269  var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
   270  
   271  type LogicalDNSRule struct {
   272  	abstractLogicalRule
   273  	disableCache bool
   274  	rewriteTTL   *uint32
   275  	clientSubnet *netip.Prefix
   276  }
   277  
   278  func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
   279  	r := &LogicalDNSRule{
   280  		abstractLogicalRule: abstractLogicalRule{
   281  			rules:    make([]adapter.HeadlessRule, len(options.Rules)),
   282  			invert:   options.Invert,
   283  			outbound: options.Server,
   284  		},
   285  		disableCache: options.DisableCache,
   286  		rewriteTTL:   options.RewriteTTL,
   287  		clientSubnet: (*netip.Prefix)(options.ClientSubnet),
   288  	}
   289  	switch options.Mode {
   290  	case C.LogicalTypeAnd:
   291  		r.mode = C.LogicalTypeAnd
   292  	case C.LogicalTypeOr:
   293  		r.mode = C.LogicalTypeOr
   294  	default:
   295  		return nil, E.New("unknown logical mode: ", options.Mode)
   296  	}
   297  	for i, subRule := range options.Rules {
   298  		rule, err := NewDNSRule(router, logger, subRule, false)
   299  		if err != nil {
   300  			return nil, E.Cause(err, "sub rule[", i, "]")
   301  		}
   302  		r.rules[i] = rule
   303  	}
   304  	return r, nil
   305  }
   306  
   307  func (r *LogicalDNSRule) DisableCache() bool {
   308  	return r.disableCache
   309  }
   310  
   311  func (r *LogicalDNSRule) RewriteTTL() *uint32 {
   312  	return r.rewriteTTL
   313  }
   314  
   315  func (r *LogicalDNSRule) ClientSubnet() *netip.Prefix {
   316  	return r.clientSubnet
   317  }
   318  
   319  func (r *LogicalDNSRule) WithAddressLimit() bool {
   320  	for _, rawRule := range r.rules {
   321  		switch rule := rawRule.(type) {
   322  		case *DefaultDNSRule:
   323  			if rule.WithAddressLimit() {
   324  				return true
   325  			}
   326  		case *LogicalDNSRule:
   327  			if rule.WithAddressLimit() {
   328  				return true
   329  			}
   330  		}
   331  	}
   332  	return false
   333  }
   334  
   335  func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
   336  	if r.mode == C.LogicalTypeAnd {
   337  		return common.All(r.rules, func(it adapter.HeadlessRule) bool {
   338  			metadata.ResetRuleCache()
   339  			return it.(adapter.DNSRule).Match(metadata)
   340  		}) != r.invert
   341  	} else {
   342  		return common.Any(r.rules, func(it adapter.HeadlessRule) bool {
   343  			metadata.ResetRuleCache()
   344  			return it.(adapter.DNSRule).Match(metadata)
   345  		}) != r.invert
   346  	}
   347  }
   348  
   349  func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
   350  	if r.mode == C.LogicalTypeAnd {
   351  		return common.All(r.rules, func(it adapter.HeadlessRule) bool {
   352  			metadata.ResetRuleCache()
   353  			return it.(adapter.DNSRule).MatchAddressLimit(metadata)
   354  		}) != r.invert
   355  	} else {
   356  		return common.Any(r.rules, func(it adapter.HeadlessRule) bool {
   357  			metadata.ResetRuleCache()
   358  			return it.(adapter.DNSRule).MatchAddressLimit(metadata)
   359  		}) != r.invert
   360  	}
   361  }