github.com/sagernet/sing-box@v1.2.7/route/rule_dns.go (about)

     1  package route
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/sagernet/sing-box/adapter"
     7  	C "github.com/sagernet/sing-box/constant"
     8  	"github.com/sagernet/sing-box/log"
     9  	"github.com/sagernet/sing-box/option"
    10  	"github.com/sagernet/sing/common"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	F "github.com/sagernet/sing/common/format"
    13  	N "github.com/sagernet/sing/common/network"
    14  )
    15  
    16  func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) {
    17  	switch options.Type {
    18  	case "", C.RuleTypeDefault:
    19  		if !options.DefaultOptions.IsValid() {
    20  			return nil, E.New("missing conditions")
    21  		}
    22  		if options.DefaultOptions.Server == "" {
    23  			return nil, E.New("missing server field")
    24  		}
    25  		return NewDefaultDNSRule(router, logger, options.DefaultOptions)
    26  	case C.RuleTypeLogical:
    27  		if !options.LogicalOptions.IsValid() {
    28  			return nil, E.New("missing conditions")
    29  		}
    30  		if options.LogicalOptions.Server == "" {
    31  			return nil, E.New("missing server field")
    32  		}
    33  		return NewLogicalDNSRule(router, logger, options.LogicalOptions)
    34  	default:
    35  		return nil, E.New("unknown rule type: ", options.Type)
    36  	}
    37  }
    38  
    39  var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
    40  
    41  type DefaultDNSRule struct {
    42  	items                   []RuleItem
    43  	sourceAddressItems      []RuleItem
    44  	sourcePortItems         []RuleItem
    45  	destinationAddressItems []RuleItem
    46  	destinationPortItems    []RuleItem
    47  	allItems                []RuleItem
    48  	invert                  bool
    49  	outbound                string
    50  	disableCache            bool
    51  }
    52  
    53  func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
    54  	rule := &DefaultDNSRule{
    55  		invert:       options.Invert,
    56  		outbound:     options.Server,
    57  		disableCache: options.DisableCache,
    58  	}
    59  	if len(options.Inbound) > 0 {
    60  		item := NewInboundRule(options.Inbound)
    61  		rule.items = append(rule.items, item)
    62  		rule.allItems = append(rule.allItems, item)
    63  	}
    64  	if options.IPVersion > 0 {
    65  		switch options.IPVersion {
    66  		case 4, 6:
    67  			item := NewIPVersionItem(options.IPVersion == 6)
    68  			rule.items = append(rule.items, item)
    69  			rule.allItems = append(rule.allItems, item)
    70  		default:
    71  			return nil, E.New("invalid ip version: ", options.IPVersion)
    72  		}
    73  	}
    74  	if len(options.QueryType) > 0 {
    75  		item := NewQueryTypeItem(options.QueryType)
    76  		rule.items = append(rule.items, item)
    77  		rule.allItems = append(rule.allItems, item)
    78  	}
    79  	if options.Network != "" {
    80  		switch options.Network {
    81  		case N.NetworkTCP, N.NetworkUDP:
    82  			item := NewNetworkItem(options.Network)
    83  			rule.items = append(rule.items, item)
    84  			rule.allItems = append(rule.allItems, item)
    85  		default:
    86  			return nil, E.New("invalid network: ", options.Network)
    87  		}
    88  	}
    89  	if len(options.AuthUser) > 0 {
    90  		item := NewAuthUserItem(options.AuthUser)
    91  		rule.items = append(rule.items, item)
    92  		rule.allItems = append(rule.allItems, item)
    93  	}
    94  	if len(options.Protocol) > 0 {
    95  		item := NewProtocolItem(options.Protocol)
    96  		rule.items = append(rule.items, item)
    97  		rule.allItems = append(rule.allItems, item)
    98  	}
    99  	if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 {
   100  		item := NewDomainItem(options.Domain, options.DomainSuffix)
   101  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   102  		rule.allItems = append(rule.allItems, item)
   103  	}
   104  	if len(options.DomainKeyword) > 0 {
   105  		item := NewDomainKeywordItem(options.DomainKeyword)
   106  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   107  		rule.allItems = append(rule.allItems, item)
   108  	}
   109  	if len(options.DomainRegex) > 0 {
   110  		item, err := NewDomainRegexItem(options.DomainRegex)
   111  		if err != nil {
   112  			return nil, E.Cause(err, "domain_regex")
   113  		}
   114  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   115  		rule.allItems = append(rule.allItems, item)
   116  	}
   117  	if len(options.Geosite) > 0 {
   118  		item := NewGeositeItem(router, logger, options.Geosite)
   119  		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
   120  		rule.allItems = append(rule.allItems, item)
   121  	}
   122  	if len(options.SourceGeoIP) > 0 {
   123  		item := NewGeoIPItem(router, logger, true, options.SourceGeoIP)
   124  		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
   125  		rule.allItems = append(rule.allItems, item)
   126  	}
   127  	if len(options.SourceIPCIDR) > 0 {
   128  		item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
   129  		if err != nil {
   130  			return nil, E.Cause(err, "source_ipcidr")
   131  		}
   132  		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
   133  		rule.allItems = append(rule.allItems, item)
   134  	}
   135  	if len(options.SourcePort) > 0 {
   136  		item := NewPortItem(true, options.SourcePort)
   137  		rule.sourcePortItems = append(rule.sourcePortItems, item)
   138  		rule.allItems = append(rule.allItems, item)
   139  	}
   140  	if len(options.SourcePortRange) > 0 {
   141  		item, err := NewPortRangeItem(true, options.SourcePortRange)
   142  		if err != nil {
   143  			return nil, E.Cause(err, "source_port_range")
   144  		}
   145  		rule.sourcePortItems = append(rule.sourcePortItems, item)
   146  		rule.allItems = append(rule.allItems, item)
   147  	}
   148  	if len(options.Port) > 0 {
   149  		item := NewPortItem(false, options.Port)
   150  		rule.destinationPortItems = append(rule.destinationPortItems, item)
   151  		rule.allItems = append(rule.allItems, item)
   152  	}
   153  	if len(options.PortRange) > 0 {
   154  		item, err := NewPortRangeItem(false, options.PortRange)
   155  		if err != nil {
   156  			return nil, E.Cause(err, "port_range")
   157  		}
   158  		rule.destinationPortItems = append(rule.destinationPortItems, item)
   159  		rule.allItems = append(rule.allItems, item)
   160  	}
   161  	if len(options.ProcessName) > 0 {
   162  		item := NewProcessItem(options.ProcessName)
   163  		rule.items = append(rule.items, item)
   164  		rule.allItems = append(rule.allItems, item)
   165  	}
   166  	if len(options.ProcessPath) > 0 {
   167  		item := NewProcessPathItem(options.ProcessPath)
   168  		rule.items = append(rule.items, item)
   169  		rule.allItems = append(rule.allItems, item)
   170  	}
   171  	if len(options.PackageName) > 0 {
   172  		item := NewPackageNameItem(options.PackageName)
   173  		rule.items = append(rule.items, item)
   174  		rule.allItems = append(rule.allItems, item)
   175  	}
   176  	if len(options.User) > 0 {
   177  		item := NewUserItem(options.User)
   178  		rule.items = append(rule.items, item)
   179  		rule.allItems = append(rule.allItems, item)
   180  	}
   181  	if len(options.UserID) > 0 {
   182  		item := NewUserIDItem(options.UserID)
   183  		rule.items = append(rule.items, item)
   184  		rule.allItems = append(rule.allItems, item)
   185  	}
   186  	if len(options.Outbound) > 0 {
   187  		item := NewOutboundRule(options.Outbound)
   188  		rule.items = append(rule.items, item)
   189  		rule.allItems = append(rule.allItems, item)
   190  	}
   191  	if options.ClashMode != "" {
   192  		item := NewClashModeItem(router, options.ClashMode)
   193  		rule.items = append(rule.items, item)
   194  		rule.allItems = append(rule.allItems, item)
   195  	}
   196  	return rule, nil
   197  }
   198  
   199  func (r *DefaultDNSRule) Type() string {
   200  	return C.RuleTypeDefault
   201  }
   202  
   203  func (r *DefaultDNSRule) Start() error {
   204  	for _, item := range r.allItems {
   205  		err := common.Start(item)
   206  		if err != nil {
   207  			return err
   208  		}
   209  	}
   210  	return nil
   211  }
   212  
   213  func (r *DefaultDNSRule) Close() error {
   214  	for _, item := range r.allItems {
   215  		err := common.Close(item)
   216  		if err != nil {
   217  			return err
   218  		}
   219  	}
   220  	return nil
   221  }
   222  
   223  func (r *DefaultDNSRule) UpdateGeosite() error {
   224  	for _, item := range r.allItems {
   225  		if geositeItem, isSite := item.(*GeositeItem); isSite {
   226  			err := geositeItem.Update()
   227  			if err != nil {
   228  				return err
   229  			}
   230  		}
   231  	}
   232  	return nil
   233  }
   234  
   235  func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
   236  	for _, item := range r.items {
   237  		if !item.Match(metadata) {
   238  			return r.invert
   239  		}
   240  	}
   241  
   242  	if len(r.sourceAddressItems) > 0 {
   243  		var sourceAddressMatch bool
   244  		for _, item := range r.sourceAddressItems {
   245  			if item.Match(metadata) {
   246  				sourceAddressMatch = true
   247  				break
   248  			}
   249  		}
   250  		if !sourceAddressMatch {
   251  			return r.invert
   252  		}
   253  	}
   254  
   255  	if len(r.sourcePortItems) > 0 {
   256  		var sourcePortMatch bool
   257  		for _, item := range r.sourcePortItems {
   258  			if item.Match(metadata) {
   259  				sourcePortMatch = true
   260  				break
   261  			}
   262  		}
   263  		if !sourcePortMatch {
   264  			return r.invert
   265  		}
   266  	}
   267  
   268  	if len(r.destinationAddressItems) > 0 {
   269  		var destinationAddressMatch bool
   270  		for _, item := range r.destinationAddressItems {
   271  			if item.Match(metadata) {
   272  				destinationAddressMatch = true
   273  				break
   274  			}
   275  		}
   276  		if !destinationAddressMatch {
   277  			return r.invert
   278  		}
   279  	}
   280  
   281  	if len(r.destinationPortItems) > 0 {
   282  		var destinationPortMatch bool
   283  		for _, item := range r.destinationPortItems {
   284  			if item.Match(metadata) {
   285  				destinationPortMatch = true
   286  				break
   287  			}
   288  		}
   289  		if !destinationPortMatch {
   290  			return r.invert
   291  		}
   292  	}
   293  
   294  	return !r.invert
   295  }
   296  
   297  func (r *DefaultDNSRule) Outbound() string {
   298  	return r.outbound
   299  }
   300  
   301  func (r *DefaultDNSRule) DisableCache() bool {
   302  	return r.disableCache
   303  }
   304  
   305  func (r *DefaultDNSRule) String() string {
   306  	return strings.Join(F.MapToString(r.allItems), " ")
   307  }
   308  
   309  var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
   310  
   311  type LogicalDNSRule struct {
   312  	mode         string
   313  	rules        []*DefaultDNSRule
   314  	invert       bool
   315  	outbound     string
   316  	disableCache bool
   317  }
   318  
   319  func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
   320  	r := &LogicalDNSRule{
   321  		rules:        make([]*DefaultDNSRule, len(options.Rules)),
   322  		invert:       options.Invert,
   323  		outbound:     options.Server,
   324  		disableCache: options.DisableCache,
   325  	}
   326  	switch options.Mode {
   327  	case C.LogicalTypeAnd:
   328  		r.mode = C.LogicalTypeAnd
   329  	case C.LogicalTypeOr:
   330  		r.mode = C.LogicalTypeOr
   331  	default:
   332  		return nil, E.New("unknown logical mode: ", options.Mode)
   333  	}
   334  	for i, subRule := range options.Rules {
   335  		rule, err := NewDefaultDNSRule(router, logger, subRule)
   336  		if err != nil {
   337  			return nil, E.Cause(err, "sub rule[", i, "]")
   338  		}
   339  		r.rules[i] = rule
   340  	}
   341  	return r, nil
   342  }
   343  
   344  func (r *LogicalDNSRule) Type() string {
   345  	return C.RuleTypeLogical
   346  }
   347  
   348  func (r *LogicalDNSRule) UpdateGeosite() error {
   349  	for _, rule := range r.rules {
   350  		err := rule.UpdateGeosite()
   351  		if err != nil {
   352  			return err
   353  		}
   354  	}
   355  	return nil
   356  }
   357  
   358  func (r *LogicalDNSRule) Start() error {
   359  	for _, rule := range r.rules {
   360  		err := rule.Start()
   361  		if err != nil {
   362  			return err
   363  		}
   364  	}
   365  	return nil
   366  }
   367  
   368  func (r *LogicalDNSRule) Close() error {
   369  	for _, rule := range r.rules {
   370  		err := rule.Close()
   371  		if err != nil {
   372  			return err
   373  		}
   374  	}
   375  	return nil
   376  }
   377  
   378  func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
   379  	if r.mode == C.LogicalTypeAnd {
   380  		return common.All(r.rules, func(it *DefaultDNSRule) bool {
   381  			return it.Match(metadata)
   382  		}) != r.invert
   383  	} else {
   384  		return common.Any(r.rules, func(it *DefaultDNSRule) bool {
   385  			return it.Match(metadata)
   386  		}) != r.invert
   387  	}
   388  }
   389  
   390  func (r *LogicalDNSRule) Outbound() string {
   391  	return r.outbound
   392  }
   393  
   394  func (r *LogicalDNSRule) DisableCache() bool {
   395  	return r.disableCache
   396  }
   397  
   398  func (r *LogicalDNSRule) String() string {
   399  	var op string
   400  	switch r.mode {
   401  	case C.LogicalTypeAnd:
   402  		op = "&&"
   403  	case C.LogicalTypeOr:
   404  		op = "||"
   405  	}
   406  	if !r.invert {
   407  		return strings.Join(F.MapToString(r.rules), " "+op+" ")
   408  	} else {
   409  		return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
   410  	}
   411  }