github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/domain/matcher.go (about)

     1  package domain
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"sort"
     7  	"unicode/utf8"
     8  
     9  	"github.com/sagernet/sing/common/rw"
    10  )
    11  
    12  type Matcher struct {
    13  	set *succinctSet
    14  }
    15  
    16  func NewMatcher(domains []string, domainSuffix []string) *Matcher {
    17  	domainList := make([]string, 0, len(domains)+2*len(domainSuffix))
    18  	seen := make(map[string]bool, len(domainList))
    19  	for _, domain := range domainSuffix {
    20  		if seen[domain] {
    21  			continue
    22  		}
    23  		seen[domain] = true
    24  		if domain[0] == '.' {
    25  			domainList = append(domainList, reverseDomainSuffix(domain))
    26  		} else {
    27  			domainList = append(domainList, reverseDomain(domain))
    28  			domainList = append(domainList, reverseRootDomainSuffix(domain))
    29  		}
    30  	}
    31  	for _, domain := range domains {
    32  		if seen[domain] {
    33  			continue
    34  		}
    35  		seen[domain] = true
    36  		domainList = append(domainList, reverseDomain(domain))
    37  	}
    38  	sort.Strings(domainList)
    39  	return &Matcher{newSuccinctSet(domainList)}
    40  }
    41  
    42  func ReadMatcher(reader io.Reader) (*Matcher, error) {
    43  	var version uint8
    44  	err := binary.Read(reader, binary.BigEndian, &version)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	leavesLength, err := rw.ReadUVariant(reader)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	leaves := make([]uint64, leavesLength)
    53  	err = binary.Read(reader, binary.BigEndian, leaves)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	labelBitmapLength, err := rw.ReadUVariant(reader)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	labelBitmap := make([]uint64, labelBitmapLength)
    62  	err = binary.Read(reader, binary.BigEndian, labelBitmap)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	labelsLength, err := rw.ReadUVariant(reader)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	labels := make([]byte, labelsLength)
    71  	_, err = io.ReadFull(reader, labels)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	set := &succinctSet{
    76  		leaves:      leaves,
    77  		labelBitmap: labelBitmap,
    78  		labels:      labels,
    79  	}
    80  	set.init()
    81  	return &Matcher{set}, nil
    82  }
    83  
    84  func (m *Matcher) Match(domain string) bool {
    85  	return m.set.Has(reverseDomain(domain))
    86  }
    87  
    88  func (m *Matcher) Write(writer io.Writer) error {
    89  	err := binary.Write(writer, binary.BigEndian, byte(1))
    90  	if err != nil {
    91  		return err
    92  	}
    93  	err = rw.WriteUVariant(writer, uint64(len(m.set.leaves)))
    94  	if err != nil {
    95  		return err
    96  	}
    97  	err = binary.Write(writer, binary.BigEndian, m.set.leaves)
    98  	if err != nil {
    99  		return err
   100  	}
   101  	err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap)))
   102  	if err != nil {
   103  		return err
   104  	}
   105  	err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap)
   106  	if err != nil {
   107  		return err
   108  	}
   109  	err = rw.WriteUVariant(writer, uint64(len(m.set.labels)))
   110  	if err != nil {
   111  		return err
   112  	}
   113  	_, err = writer.Write(m.set.labels)
   114  	if err != nil {
   115  		return err
   116  	}
   117  	return nil
   118  }
   119  
   120  func reverseDomain(domain string) string {
   121  	l := len(domain)
   122  	b := make([]byte, l)
   123  	for i := 0; i < l; {
   124  		r, n := utf8.DecodeRuneInString(domain[i:])
   125  		i += n
   126  		utf8.EncodeRune(b[l-i:], r)
   127  	}
   128  	return string(b)
   129  }
   130  
   131  func reverseDomainSuffix(domain string) string {
   132  	l := len(domain)
   133  	b := make([]byte, l+1)
   134  	for i := 0; i < l; {
   135  		r, n := utf8.DecodeRuneInString(domain[i:])
   136  		i += n
   137  		utf8.EncodeRune(b[l-i:], r)
   138  	}
   139  	b[l] = prefixLabel
   140  	return string(b)
   141  }
   142  
   143  func reverseRootDomainSuffix(domain string) string {
   144  	l := len(domain)
   145  	b := make([]byte, l+2)
   146  	for i := 0; i < l; {
   147  		r, n := utf8.DecodeRuneInString(domain[i:])
   148  		i += n
   149  		utf8.EncodeRune(b[l-i:], r)
   150  	}
   151  	b[l] = '.'
   152  	b[l+1] = prefixLabel
   153  	return string(b)
   154  }