github.com/xraypb/xray-core@v1.6.6/common/strmatcher/mph_matcher.go (about)

     1  package strmatcher
     2  
     3  import (
     4  	"math/bits"
     5  	"regexp"
     6  	"sort"
     7  	"strings"
     8  	"unsafe"
     9  )
    10  
    11  // PrimeRK is the prime base used in Rabin-Karp algorithm.
    12  const PrimeRK = 16777619
    13  
    14  // calculate the rolling murmurHash of given string
    15  func RollingHash(s string) uint32 {
    16  	h := uint32(0)
    17  	for i := len(s) - 1; i >= 0; i-- {
    18  		h = h*PrimeRK + uint32(s[i])
    19  	}
    20  	return h
    21  }
    22  
    23  // A MphMatcherGroup is divided into three parts:
    24  // 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table;
    25  // 2. `substr` patterns are matched by ac automaton;
    26  // 3. `regex` patterns are matched with the regex library.
    27  type MphMatcherGroup struct {
    28  	ac            *ACAutomaton
    29  	otherMatchers []matcherEntry
    30  	rules         []string
    31  	level0        []uint32
    32  	level0Mask    int
    33  	level1        []uint32
    34  	level1Mask    int
    35  	count         uint32
    36  	ruleMap       *map[string]uint32
    37  }
    38  
    39  func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) {
    40  	h := RollingHash(pattern)
    41  	switch t {
    42  	case Domain:
    43  		(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
    44  		fallthrough
    45  	case Full:
    46  		(*g.ruleMap)[pattern] = h
    47  	default:
    48  	}
    49  }
    50  
    51  func NewMphMatcherGroup() *MphMatcherGroup {
    52  	return &MphMatcherGroup{
    53  		ac:            nil,
    54  		otherMatchers: nil,
    55  		rules:         nil,
    56  		level0:        nil,
    57  		level0Mask:    0,
    58  		level1:        nil,
    59  		level1Mask:    0,
    60  		count:         1,
    61  		ruleMap:       &map[string]uint32{},
    62  	}
    63  }
    64  
    65  // AddPattern adds a pattern to MphMatcherGroup
    66  func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
    67  	switch t {
    68  	case Substr:
    69  		if g.ac == nil {
    70  			g.ac = NewACAutomaton()
    71  		}
    72  		g.ac.Add(pattern, t)
    73  	case Full, Domain:
    74  		pattern = strings.ToLower(pattern)
    75  		g.AddFullOrDomainPattern(pattern, t)
    76  	case Regex:
    77  		r, err := regexp.Compile(pattern)
    78  		if err != nil {
    79  			return 0, err
    80  		}
    81  		g.otherMatchers = append(g.otherMatchers, matcherEntry{
    82  			m:  &regexMatcher{pattern: r},
    83  			id: g.count,
    84  		})
    85  	default:
    86  		panic("Unknown type")
    87  	}
    88  	return g.count, nil
    89  }
    90  
    91  // Build builds a minimal perfect hash table and ac automaton from insert rules
    92  func (g *MphMatcherGroup) Build() {
    93  	if g.ac != nil {
    94  		g.ac.Build()
    95  	}
    96  	keyLen := len(*g.ruleMap)
    97  	if keyLen == 0 {
    98  		keyLen = 1
    99  		(*g.ruleMap)["empty___"] = RollingHash("empty___")
   100  	}
   101  	g.level0 = make([]uint32, nextPow2(keyLen/4))
   102  	g.level0Mask = len(g.level0) - 1
   103  	g.level1 = make([]uint32, nextPow2(keyLen))
   104  	g.level1Mask = len(g.level1) - 1
   105  	sparseBuckets := make([][]int, len(g.level0))
   106  	var ruleIdx int
   107  	for rule, hash := range *g.ruleMap {
   108  		n := int(hash) & g.level0Mask
   109  		g.rules = append(g.rules, rule)
   110  		sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
   111  		ruleIdx++
   112  	}
   113  	g.ruleMap = nil
   114  	var buckets []indexBucket
   115  	for n, vals := range sparseBuckets {
   116  		if len(vals) > 0 {
   117  			buckets = append(buckets, indexBucket{n, vals})
   118  		}
   119  	}
   120  	sort.Sort(bySize(buckets))
   121  
   122  	occ := make([]bool, len(g.level1))
   123  	var tmpOcc []int
   124  	for _, bucket := range buckets {
   125  		seed := uint32(0)
   126  		for {
   127  			findSeed := true
   128  			tmpOcc = tmpOcc[:0]
   129  			for _, i := range bucket.vals {
   130  				n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask
   131  				if occ[n] {
   132  					for _, n := range tmpOcc {
   133  						occ[n] = false
   134  					}
   135  					seed++
   136  					findSeed = false
   137  					break
   138  				}
   139  				occ[n] = true
   140  				tmpOcc = append(tmpOcc, n)
   141  				g.level1[n] = uint32(i)
   142  			}
   143  			if findSeed {
   144  				g.level0[bucket.n] = seed
   145  				break
   146  			}
   147  		}
   148  	}
   149  }
   150  
   151  func nextPow2(v int) int {
   152  	if v <= 1 {
   153  		return 1
   154  	}
   155  	const MaxUInt = ^uint(0)
   156  	n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1
   157  	return int(n)
   158  }
   159  
   160  // Lookup searches for s in t and returns its index and whether it was found.
   161  func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
   162  	i0 := int(h) & g.level0Mask
   163  	seed := g.level0[i0]
   164  	i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask
   165  	n := g.level1[i1]
   166  	return s == g.rules[int(n)]
   167  }
   168  
   169  // Match implements IndexMatcher.Match.
   170  func (g *MphMatcherGroup) Match(pattern string) []uint32 {
   171  	result := []uint32{}
   172  	hash := uint32(0)
   173  	for i := len(pattern) - 1; i >= 0; i-- {
   174  		hash = hash*PrimeRK + uint32(pattern[i])
   175  		if pattern[i] == '.' {
   176  			if g.Lookup(hash, pattern[i:]) {
   177  				result = append(result, 1)
   178  				return result
   179  			}
   180  		}
   181  	}
   182  	if g.Lookup(hash, pattern) {
   183  		result = append(result, 1)
   184  		return result
   185  	}
   186  	if g.ac != nil && g.ac.Match(pattern) {
   187  		result = append(result, 1)
   188  		return result
   189  	}
   190  	for _, e := range g.otherMatchers {
   191  		if e.m.Match(pattern) {
   192  			result = append(result, e.id)
   193  			return result
   194  		}
   195  	}
   196  	return nil
   197  }
   198  
   199  type indexBucket struct {
   200  	n    int
   201  	vals []int
   202  }
   203  
   204  type bySize []indexBucket
   205  
   206  func (s bySize) Len() int           { return len(s) }
   207  func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
   208  func (s bySize) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
   209  
   210  type stringStruct struct {
   211  	str unsafe.Pointer
   212  	len int
   213  }
   214  
   215  func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
   216  	x := (*stringStruct)(a)
   217  	return memhashFallback(x.str, h, uintptr(x.len))
   218  }
   219  
   220  const (
   221  	// Constants for multiplication: four random odd 64-bit numbers.
   222  	m1 = 16877499708836156737
   223  	m2 = 2820277070424839065
   224  	m3 = 9497967016996688599
   225  	m4 = 15839092249703872147
   226  )
   227  
   228  var hashkey = [4]uintptr{1, 1, 1, 1}
   229  
   230  func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
   231  	h := uint64(seed + s*hashkey[0])
   232  tail:
   233  	switch {
   234  	case s == 0:
   235  	case s < 4:
   236  		h ^= uint64(*(*byte)(p))
   237  		h ^= uint64(*(*byte)(add(p, s>>1))) << 8
   238  		h ^= uint64(*(*byte)(add(p, s-1))) << 16
   239  		h = rotl31(h*m1) * m2
   240  	case s <= 8:
   241  		h ^= uint64(readUnaligned32(p))
   242  		h ^= uint64(readUnaligned32(add(p, s-4))) << 32
   243  		h = rotl31(h*m1) * m2
   244  	case s <= 16:
   245  		h ^= readUnaligned64(p)
   246  		h = rotl31(h*m1) * m2
   247  		h ^= readUnaligned64(add(p, s-8))
   248  		h = rotl31(h*m1) * m2
   249  	case s <= 32:
   250  		h ^= readUnaligned64(p)
   251  		h = rotl31(h*m1) * m2
   252  		h ^= readUnaligned64(add(p, 8))
   253  		h = rotl31(h*m1) * m2
   254  		h ^= readUnaligned64(add(p, s-16))
   255  		h = rotl31(h*m1) * m2
   256  		h ^= readUnaligned64(add(p, s-8))
   257  		h = rotl31(h*m1) * m2
   258  	default:
   259  		v1 := h
   260  		v2 := uint64(seed * hashkey[1])
   261  		v3 := uint64(seed * hashkey[2])
   262  		v4 := uint64(seed * hashkey[3])
   263  		for s >= 32 {
   264  			v1 ^= readUnaligned64(p)
   265  			v1 = rotl31(v1*m1) * m2
   266  			p = add(p, 8)
   267  			v2 ^= readUnaligned64(p)
   268  			v2 = rotl31(v2*m2) * m3
   269  			p = add(p, 8)
   270  			v3 ^= readUnaligned64(p)
   271  			v3 = rotl31(v3*m3) * m4
   272  			p = add(p, 8)
   273  			v4 ^= readUnaligned64(p)
   274  			v4 = rotl31(v4*m4) * m1
   275  			p = add(p, 8)
   276  			s -= 32
   277  		}
   278  		h = v1 ^ v2 ^ v3 ^ v4
   279  		goto tail
   280  	}
   281  
   282  	h ^= h >> 29
   283  	h *= m3
   284  	h ^= h >> 32
   285  	return uintptr(h)
   286  }
   287  
   288  func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
   289  	return unsafe.Pointer(uintptr(p) + x)
   290  }
   291  
   292  func readUnaligned32(p unsafe.Pointer) uint32 {
   293  	q := (*[4]byte)(p)
   294  	return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
   295  }
   296  
   297  func rotl31(x uint64) uint64 {
   298  	return (x << 31) | (x >> (64 - 31))
   299  }
   300  
   301  func readUnaligned64(p unsafe.Pointer) uint64 {
   302  	q := (*[8]byte)(p)
   303  	return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
   304  }