github.com/sagernet/sing-box@v1.9.0-rc.20/common/srs/binary.go (about)

     1  package srs
     2  
     3  import (
     4  	"compress/zlib"
     5  	"encoding/binary"
     6  	"io"
     7  	"net/netip"
     8  
     9  	C "github.com/sagernet/sing-box/constant"
    10  	"github.com/sagernet/sing-box/option"
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/domain"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	"github.com/sagernet/sing/common/rw"
    15  
    16  	"go4.org/netipx"
    17  )
    18  
    19  var MagicBytes = [3]byte{0x53, 0x52, 0x53} // SRS
    20  
    21  const (
    22  	ruleItemQueryType uint8 = iota
    23  	ruleItemNetwork
    24  	ruleItemDomain
    25  	ruleItemDomainKeyword
    26  	ruleItemDomainRegex
    27  	ruleItemSourceIPCIDR
    28  	ruleItemIPCIDR
    29  	ruleItemSourcePort
    30  	ruleItemSourcePortRange
    31  	ruleItemPort
    32  	ruleItemPortRange
    33  	ruleItemProcessName
    34  	ruleItemProcessPath
    35  	ruleItemPackageName
    36  	ruleItemWIFISSID
    37  	ruleItemWIFIBSSID
    38  	ruleItemFinal uint8 = 0xFF
    39  )
    40  
    41  func Read(reader io.Reader, recovery bool) (ruleSet option.PlainRuleSet, err error) {
    42  	var magicBytes [3]byte
    43  	_, err = io.ReadFull(reader, magicBytes[:])
    44  	if err != nil {
    45  		return
    46  	}
    47  	if magicBytes != MagicBytes {
    48  		err = E.New("invalid sing-box rule set file")
    49  		return
    50  	}
    51  	var version uint8
    52  	err = binary.Read(reader, binary.BigEndian, &version)
    53  	if err != nil {
    54  		return ruleSet, err
    55  	}
    56  	if version != 1 {
    57  		return ruleSet, E.New("unsupported version: ", version)
    58  	}
    59  	zReader, err := zlib.NewReader(reader)
    60  	if err != nil {
    61  		return
    62  	}
    63  	length, err := rw.ReadUVariant(zReader)
    64  	if err != nil {
    65  		return
    66  	}
    67  	ruleSet.Rules = make([]option.HeadlessRule, length)
    68  	for i := uint64(0); i < length; i++ {
    69  		ruleSet.Rules[i], err = readRule(zReader, recovery)
    70  		if err != nil {
    71  			err = E.Cause(err, "read rule[", i, "]")
    72  			return
    73  		}
    74  	}
    75  	return
    76  }
    77  
    78  func Write(writer io.Writer, ruleSet option.PlainRuleSet) error {
    79  	_, err := writer.Write(MagicBytes[:])
    80  	if err != nil {
    81  		return err
    82  	}
    83  	err = binary.Write(writer, binary.BigEndian, uint8(1))
    84  	if err != nil {
    85  		return err
    86  	}
    87  	zWriter, err := zlib.NewWriterLevel(writer, zlib.BestCompression)
    88  	if err != nil {
    89  		return err
    90  	}
    91  	err = rw.WriteUVariant(zWriter, uint64(len(ruleSet.Rules)))
    92  	if err != nil {
    93  		return err
    94  	}
    95  	for _, rule := range ruleSet.Rules {
    96  		err = writeRule(zWriter, rule)
    97  		if err != nil {
    98  			return err
    99  		}
   100  	}
   101  	return zWriter.Close()
   102  }
   103  
   104  func readRule(reader io.Reader, recovery bool) (rule option.HeadlessRule, err error) {
   105  	var ruleType uint8
   106  	err = binary.Read(reader, binary.BigEndian, &ruleType)
   107  	if err != nil {
   108  		return
   109  	}
   110  	switch ruleType {
   111  	case 0:
   112  		rule.Type = C.RuleTypeDefault
   113  		rule.DefaultOptions, err = readDefaultRule(reader, recovery)
   114  	case 1:
   115  		rule.Type = C.RuleTypeLogical
   116  		rule.LogicalOptions, err = readLogicalRule(reader, recovery)
   117  	default:
   118  		err = E.New("unknown rule type: ", ruleType)
   119  	}
   120  	return
   121  }
   122  
   123  func writeRule(writer io.Writer, rule option.HeadlessRule) error {
   124  	switch rule.Type {
   125  	case C.RuleTypeDefault:
   126  		return writeDefaultRule(writer, rule.DefaultOptions)
   127  	case C.RuleTypeLogical:
   128  		return writeLogicalRule(writer, rule.LogicalOptions)
   129  	default:
   130  		panic("unknown rule type: " + rule.Type)
   131  	}
   132  }
   133  
   134  func readDefaultRule(reader io.Reader, recovery bool) (rule option.DefaultHeadlessRule, err error) {
   135  	var lastItemType uint8
   136  	for {
   137  		var itemType uint8
   138  		err = binary.Read(reader, binary.BigEndian, &itemType)
   139  		if err != nil {
   140  			return
   141  		}
   142  		switch itemType {
   143  		case ruleItemQueryType:
   144  			var rawQueryType []uint16
   145  			rawQueryType, err = readRuleItemUint16(reader)
   146  			if err != nil {
   147  				return
   148  			}
   149  			rule.QueryType = common.Map(rawQueryType, func(it uint16) option.DNSQueryType {
   150  				return option.DNSQueryType(it)
   151  			})
   152  		case ruleItemNetwork:
   153  			rule.Network, err = readRuleItemString(reader)
   154  		case ruleItemDomain:
   155  			var matcher *domain.Matcher
   156  			matcher, err = domain.ReadMatcher(reader)
   157  			if err != nil {
   158  				return
   159  			}
   160  			rule.DomainMatcher = matcher
   161  		case ruleItemDomainKeyword:
   162  			rule.DomainKeyword, err = readRuleItemString(reader)
   163  		case ruleItemDomainRegex:
   164  			rule.DomainRegex, err = readRuleItemString(reader)
   165  		case ruleItemSourceIPCIDR:
   166  			rule.SourceIPSet, err = readIPSet(reader)
   167  			if err != nil {
   168  				return
   169  			}
   170  			if recovery {
   171  				rule.SourceIPCIDR = common.Map(rule.SourceIPSet.Prefixes(), netip.Prefix.String)
   172  			}
   173  		case ruleItemIPCIDR:
   174  			rule.IPSet, err = readIPSet(reader)
   175  			if err != nil {
   176  				return
   177  			}
   178  			if recovery {
   179  				rule.IPCIDR = common.Map(rule.IPSet.Prefixes(), netip.Prefix.String)
   180  			}
   181  		case ruleItemSourcePort:
   182  			rule.SourcePort, err = readRuleItemUint16(reader)
   183  		case ruleItemSourcePortRange:
   184  			rule.SourcePortRange, err = readRuleItemString(reader)
   185  		case ruleItemPort:
   186  			rule.Port, err = readRuleItemUint16(reader)
   187  		case ruleItemPortRange:
   188  			rule.PortRange, err = readRuleItemString(reader)
   189  		case ruleItemProcessName:
   190  			rule.ProcessName, err = readRuleItemString(reader)
   191  		case ruleItemProcessPath:
   192  			rule.ProcessPath, err = readRuleItemString(reader)
   193  		case ruleItemPackageName:
   194  			rule.PackageName, err = readRuleItemString(reader)
   195  		case ruleItemWIFISSID:
   196  			rule.WIFISSID, err = readRuleItemString(reader)
   197  		case ruleItemWIFIBSSID:
   198  			rule.WIFIBSSID, err = readRuleItemString(reader)
   199  		case ruleItemFinal:
   200  			err = binary.Read(reader, binary.BigEndian, &rule.Invert)
   201  			return
   202  		default:
   203  			err = E.New("unknown rule item type: ", itemType, ", last type: ", lastItemType)
   204  		}
   205  		if err != nil {
   206  			return
   207  		}
   208  		lastItemType = itemType
   209  	}
   210  }
   211  
   212  func writeDefaultRule(writer io.Writer, rule option.DefaultHeadlessRule) error {
   213  	err := binary.Write(writer, binary.BigEndian, uint8(0))
   214  	if err != nil {
   215  		return err
   216  	}
   217  	if len(rule.QueryType) > 0 {
   218  		err = writeRuleItemUint16(writer, ruleItemQueryType, common.Map(rule.QueryType, func(it option.DNSQueryType) uint16 {
   219  			return uint16(it)
   220  		}))
   221  		if err != nil {
   222  			return err
   223  		}
   224  	}
   225  	if len(rule.Network) > 0 {
   226  		err = writeRuleItemString(writer, ruleItemNetwork, rule.Network)
   227  		if err != nil {
   228  			return err
   229  		}
   230  	}
   231  	if len(rule.Domain) > 0 || len(rule.DomainSuffix) > 0 {
   232  		err = binary.Write(writer, binary.BigEndian, ruleItemDomain)
   233  		if err != nil {
   234  			return err
   235  		}
   236  		err = domain.NewMatcher(rule.Domain, rule.DomainSuffix).Write(writer)
   237  		if err != nil {
   238  			return err
   239  		}
   240  	}
   241  	if len(rule.DomainKeyword) > 0 {
   242  		err = writeRuleItemString(writer, ruleItemDomainKeyword, rule.DomainKeyword)
   243  		if err != nil {
   244  			return err
   245  		}
   246  	}
   247  	if len(rule.DomainRegex) > 0 {
   248  		err = writeRuleItemString(writer, ruleItemDomainRegex, rule.DomainRegex)
   249  		if err != nil {
   250  			return err
   251  		}
   252  	}
   253  	if len(rule.SourceIPCIDR) > 0 {
   254  		err = writeRuleItemCIDR(writer, ruleItemSourceIPCIDR, rule.SourceIPCIDR)
   255  		if err != nil {
   256  			return E.Cause(err, "source_ip_cidr")
   257  		}
   258  	}
   259  	if len(rule.IPCIDR) > 0 {
   260  		err = writeRuleItemCIDR(writer, ruleItemIPCIDR, rule.IPCIDR)
   261  		if err != nil {
   262  			return E.Cause(err, "ipcidr")
   263  		}
   264  	}
   265  	if len(rule.SourcePort) > 0 {
   266  		err = writeRuleItemUint16(writer, ruleItemSourcePort, rule.SourcePort)
   267  		if err != nil {
   268  			return err
   269  		}
   270  	}
   271  	if len(rule.SourcePortRange) > 0 {
   272  		err = writeRuleItemString(writer, ruleItemSourcePortRange, rule.SourcePortRange)
   273  		if err != nil {
   274  			return err
   275  		}
   276  	}
   277  	if len(rule.Port) > 0 {
   278  		err = writeRuleItemUint16(writer, ruleItemPort, rule.Port)
   279  		if err != nil {
   280  			return err
   281  		}
   282  	}
   283  	if len(rule.PortRange) > 0 {
   284  		err = writeRuleItemString(writer, ruleItemPortRange, rule.PortRange)
   285  		if err != nil {
   286  			return err
   287  		}
   288  	}
   289  	if len(rule.ProcessName) > 0 {
   290  		err = writeRuleItemString(writer, ruleItemProcessName, rule.ProcessName)
   291  		if err != nil {
   292  			return err
   293  		}
   294  	}
   295  	if len(rule.ProcessPath) > 0 {
   296  		err = writeRuleItemString(writer, ruleItemProcessPath, rule.ProcessPath)
   297  		if err != nil {
   298  			return err
   299  		}
   300  	}
   301  	if len(rule.PackageName) > 0 {
   302  		err = writeRuleItemString(writer, ruleItemPackageName, rule.PackageName)
   303  		if err != nil {
   304  			return err
   305  		}
   306  	}
   307  	if len(rule.WIFISSID) > 0 {
   308  		err = writeRuleItemString(writer, ruleItemWIFISSID, rule.WIFISSID)
   309  		if err != nil {
   310  			return err
   311  		}
   312  	}
   313  	if len(rule.WIFIBSSID) > 0 {
   314  		err = writeRuleItemString(writer, ruleItemWIFIBSSID, rule.WIFIBSSID)
   315  		if err != nil {
   316  			return err
   317  		}
   318  	}
   319  	err = binary.Write(writer, binary.BigEndian, ruleItemFinal)
   320  	if err != nil {
   321  		return err
   322  	}
   323  	err = binary.Write(writer, binary.BigEndian, rule.Invert)
   324  	if err != nil {
   325  		return err
   326  	}
   327  	return nil
   328  }
   329  
   330  func readRuleItemString(reader io.Reader) ([]string, error) {
   331  	length, err := rw.ReadUVariant(reader)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  	value := make([]string, length)
   336  	for i := uint64(0); i < length; i++ {
   337  		value[i], err = rw.ReadVString(reader)
   338  		if err != nil {
   339  			return nil, err
   340  		}
   341  	}
   342  	return value, nil
   343  }
   344  
   345  func writeRuleItemString(writer io.Writer, itemType uint8, value []string) error {
   346  	err := binary.Write(writer, binary.BigEndian, itemType)
   347  	if err != nil {
   348  		return err
   349  	}
   350  	err = rw.WriteUVariant(writer, uint64(len(value)))
   351  	if err != nil {
   352  		return err
   353  	}
   354  	for _, item := range value {
   355  		err = rw.WriteVString(writer, item)
   356  		if err != nil {
   357  			return err
   358  		}
   359  	}
   360  	return nil
   361  }
   362  
   363  func readRuleItemUint16(reader io.Reader) ([]uint16, error) {
   364  	length, err := rw.ReadUVariant(reader)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  	value := make([]uint16, length)
   369  	for i := uint64(0); i < length; i++ {
   370  		err = binary.Read(reader, binary.BigEndian, &value[i])
   371  		if err != nil {
   372  			return nil, err
   373  		}
   374  	}
   375  	return value, nil
   376  }
   377  
   378  func writeRuleItemUint16(writer io.Writer, itemType uint8, value []uint16) error {
   379  	err := binary.Write(writer, binary.BigEndian, itemType)
   380  	if err != nil {
   381  		return err
   382  	}
   383  	err = rw.WriteUVariant(writer, uint64(len(value)))
   384  	if err != nil {
   385  		return err
   386  	}
   387  	for _, item := range value {
   388  		err = binary.Write(writer, binary.BigEndian, item)
   389  		if err != nil {
   390  			return err
   391  		}
   392  	}
   393  	return nil
   394  }
   395  
   396  func writeRuleItemCIDR(writer io.Writer, itemType uint8, value []string) error {
   397  	var builder netipx.IPSetBuilder
   398  	for i, prefixString := range value {
   399  		prefix, err := netip.ParsePrefix(prefixString)
   400  		if err == nil {
   401  			builder.AddPrefix(prefix)
   402  			continue
   403  		}
   404  		addr, addrErr := netip.ParseAddr(prefixString)
   405  		if addrErr == nil {
   406  			builder.Add(addr)
   407  			continue
   408  		}
   409  		return E.Cause(err, "parse [", i, "]")
   410  	}
   411  	ipSet, err := builder.IPSet()
   412  	if err != nil {
   413  		return err
   414  	}
   415  	err = binary.Write(writer, binary.BigEndian, itemType)
   416  	if err != nil {
   417  		return err
   418  	}
   419  	return writeIPSet(writer, ipSet)
   420  }
   421  
   422  func readLogicalRule(reader io.Reader, recovery bool) (logicalRule option.LogicalHeadlessRule, err error) {
   423  	var mode uint8
   424  	err = binary.Read(reader, binary.BigEndian, &mode)
   425  	if err != nil {
   426  		return
   427  	}
   428  	switch mode {
   429  	case 0:
   430  		logicalRule.Mode = C.LogicalTypeAnd
   431  	case 1:
   432  		logicalRule.Mode = C.LogicalTypeOr
   433  	default:
   434  		err = E.New("unknown logical mode: ", mode)
   435  		return
   436  	}
   437  	length, err := rw.ReadUVariant(reader)
   438  	if err != nil {
   439  		return
   440  	}
   441  	logicalRule.Rules = make([]option.HeadlessRule, length)
   442  	for i := uint64(0); i < length; i++ {
   443  		logicalRule.Rules[i], err = readRule(reader, recovery)
   444  		if err != nil {
   445  			err = E.Cause(err, "read logical rule [", i, "]")
   446  			return
   447  		}
   448  	}
   449  	err = binary.Read(reader, binary.BigEndian, &logicalRule.Invert)
   450  	if err != nil {
   451  		return
   452  	}
   453  	return
   454  }
   455  
   456  func writeLogicalRule(writer io.Writer, logicalRule option.LogicalHeadlessRule) error {
   457  	err := binary.Write(writer, binary.BigEndian, uint8(1))
   458  	if err != nil {
   459  		return err
   460  	}
   461  	switch logicalRule.Mode {
   462  	case C.LogicalTypeAnd:
   463  		err = binary.Write(writer, binary.BigEndian, uint8(0))
   464  	case C.LogicalTypeOr:
   465  		err = binary.Write(writer, binary.BigEndian, uint8(1))
   466  	default:
   467  		panic("unknown logical mode: " + logicalRule.Mode)
   468  	}
   469  	if err != nil {
   470  		return err
   471  	}
   472  	err = rw.WriteUVariant(writer, uint64(len(logicalRule.Rules)))
   473  	if err != nil {
   474  		return err
   475  	}
   476  	for _, rule := range logicalRule.Rules {
   477  		err = writeRule(writer, rule)
   478  		if err != nil {
   479  			return err
   480  		}
   481  	}
   482  	err = binary.Write(writer, binary.BigEndian, logicalRule.Invert)
   483  	if err != nil {
   484  		return err
   485  	}
   486  	return nil
   487  }