github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/strmatcher/mph_matcher.go (about) 1 package strmatcher 2 3 import ( 4 "math/bits" 5 "regexp" 6 "sort" 7 "strings" 8 "unsafe" 9 ) 10 11 // PrimeRK is the prime base used in Rabin-Karp algorithm. 12 const PrimeRK = 16777619 13 14 // calculate the rolling murmurHash of given string 15 func RollingHash(s string) uint32 { 16 h := uint32(0) 17 for i := len(s) - 1; i >= 0; i-- { 18 h = h*PrimeRK + uint32(s[i]) 19 } 20 return h 21 } 22 23 // A MphMatcherGroup is divided into three parts: 24 // 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; 25 // 2. `substr` patterns are matched by ac automaton; 26 // 3. `regex` patterns are matched with the regex library. 27 type MphMatcherGroup struct { 28 ac *ACAutomaton 29 otherMatchers []matcherEntry 30 rules []string 31 level0 []uint32 32 level0Mask int 33 level1 []uint32 34 level1Mask int 35 count uint32 36 ruleMap *map[string]uint32 37 } 38 39 func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { 40 h := RollingHash(pattern) 41 switch t { 42 case Domain: 43 (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') 44 fallthrough 45 case Full: 46 (*g.ruleMap)[pattern] = h 47 default: 48 } 49 } 50 51 func NewMphMatcherGroup() *MphMatcherGroup { 52 return &MphMatcherGroup{ 53 ac: nil, 54 otherMatchers: nil, 55 rules: nil, 56 level0: nil, 57 level0Mask: 0, 58 level1: nil, 59 level1Mask: 0, 60 count: 1, 61 ruleMap: &map[string]uint32{}, 62 } 63 } 64 65 // AddPattern adds a pattern to MphMatcherGroup 66 func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { 67 switch t { 68 case Substr: 69 if g.ac == nil { 70 g.ac = NewACAutomaton() 71 } 72 g.ac.Add(pattern, t) 73 case Full, Domain: 74 pattern = strings.ToLower(pattern) 75 g.AddFullOrDomainPattern(pattern, t) 76 case Regex: 77 r, err := regexp.Compile(pattern) 78 if err != nil { 79 return 0, err 80 } 81 g.otherMatchers = 
append(g.otherMatchers, matcherEntry{ 82 m: ®exMatcher{pattern: r}, 83 id: g.count, 84 }) 85 default: 86 panic("Unknown type") 87 } 88 return g.count, nil 89 } 90 91 // Build builds a minimal perfect hash table and ac automaton from insert rules 92 func (g *MphMatcherGroup) Build() { 93 if g.ac != nil { 94 g.ac.Build() 95 } 96 keyLen := len(*g.ruleMap) 97 if keyLen == 0 { 98 keyLen = 1 99 (*g.ruleMap)["empty___"] = RollingHash("empty___") 100 } 101 g.level0 = make([]uint32, nextPow2(keyLen/4)) 102 g.level0Mask = len(g.level0) - 1 103 g.level1 = make([]uint32, nextPow2(keyLen)) 104 g.level1Mask = len(g.level1) - 1 105 sparseBuckets := make([][]int, len(g.level0)) 106 var ruleIdx int 107 for rule, hash := range *g.ruleMap { 108 n := int(hash) & g.level0Mask 109 g.rules = append(g.rules, rule) 110 sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) 111 ruleIdx++ 112 } 113 g.ruleMap = nil 114 var buckets []indexBucket 115 for n, vals := range sparseBuckets { 116 if len(vals) > 0 { 117 buckets = append(buckets, indexBucket{n, vals}) 118 } 119 } 120 sort.Sort(bySize(buckets)) 121 122 occ := make([]bool, len(g.level1)) 123 var tmpOcc []int 124 for _, bucket := range buckets { 125 seed := uint32(0) 126 for { 127 findSeed := true 128 tmpOcc = tmpOcc[:0] 129 for _, i := range bucket.vals { 130 n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask 131 if occ[n] { 132 for _, n := range tmpOcc { 133 occ[n] = false 134 } 135 seed++ 136 findSeed = false 137 break 138 } 139 occ[n] = true 140 tmpOcc = append(tmpOcc, n) 141 g.level1[n] = uint32(i) 142 } 143 if findSeed { 144 g.level0[bucket.n] = seed 145 break 146 } 147 } 148 } 149 } 150 151 func nextPow2(v int) int { 152 if v <= 1 { 153 return 1 154 } 155 const MaxUInt = ^uint(0) 156 n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 157 return int(n) 158 } 159 160 // Lookup searches for s in t and returns its index and whether it was found. 
func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
	i0 := int(h) & g.level0Mask // level-0 bucket selected by the rolling hash
	seed := g.level0[i0]        // per-bucket seed chosen during Build
	i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask
	n := g.level1[i1]
	// The table is collision-free only for inserted rules; a full string
	// comparison rejects strings that merely hash onto an occupied slot.
	return s == g.rules[int(n)]
}

// Match implements IndexMatcher.Match.
//
// It returns []uint32{1} on the first rule that matches (all rules in this
// group share id 1, except regex entries which carry their stored id) and
// nil when nothing matches.
func (g *MphMatcherGroup) Match(pattern string) []uint32 {
	result := []uint32{}
	hash := uint32(0)
	// Scan right-to-left, extending the rolling hash one byte at a time.
	// Every '.' boundary makes the current suffix a candidate domain rule.
	for i := len(pattern) - 1; i >= 0; i-- {
		hash = hash*PrimeRK + uint32(pattern[i])
		if pattern[i] == '.' {
			if g.Lookup(hash, pattern[i:]) {
				result = append(result, 1)
				return result
			}
		}
	}
	// The whole string as an exact (full/bare-domain) match.
	if g.Lookup(hash, pattern) {
		result = append(result, 1)
		return result
	}
	// Substring rules via the AC automaton, if any were added.
	if g.ac != nil && g.ac.Match(pattern) {
		result = append(result, 1)
		return result
	}
	// Finally, the regex matchers.
	for _, e := range g.otherMatchers {
		if e.m.Match(pattern) {
			result = append(result, e.id)
			return result
		}
	}
	return nil
}

// indexBucket is a level-0 bucket: slot index n plus the indices (into
// g.rules) of the rules that hashed into it.
type indexBucket struct {
	n    int
	vals []int
}

// bySize sorts buckets by descending size so Build seeds the largest (hardest)
// buckets first.
type bySize []indexBucket

func (s bySize) Len() int           { return len(s) }
func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
func (s bySize) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }

// stringStruct mirrors the runtime's string header layout so a *string can be
// reinterpreted to reach the underlying bytes without copying.
// NOTE(review): this depends on the runtime's internal string representation
// staying pointer+len — confirm against the targeted Go versions.
type stringStruct struct {
	str unsafe.Pointer
	len int
}

// strhashFallback hashes the string pointed to by a with the given seed,
// delegating to memhashFallback over the string's backing bytes.
func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
	x := (*stringStruct)(a)
	return memhashFallback(x.str, h, uintptr(x.len))
}

const (
	// Constants for multiplication: four random odd 64-bit numbers.
	m1 = 16877499708836156737
	m2 = 2820277070424839065
	m3 = 9497967016996688599
	m4 = 15839092249703872147
)

// hashkey is deliberately fixed (the runtime randomizes its equivalent at
// startup) so the hash — and therefore the built MPH table layout — is
// deterministic across runs.
var hashkey = [4]uintptr{1, 1, 1, 1}

// memhashFallback hashes s bytes starting at p with the given seed. It is
// adapted from the Go runtime's non-AES fallback memhash: small inputs are
// mixed directly by size class; inputs over 32 bytes are consumed in 32-byte
// stripes across four lanes, then the remaining tail (handled via the goto)
// is folded in by the size-class cases above.
func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
	h := uint64(seed + s*hashkey[0])
tail:
	switch {
	case s == 0:
	case s < 4:
		// 1-3 bytes: first, middle, and last byte (may overlap).
		h ^= uint64(*(*byte)(p))
		h ^= uint64(*(*byte)(add(p, s>>1))) << 8
		h ^= uint64(*(*byte)(add(p, s-1))) << 16
		h = rotl31(h*m1) * m2
	case s <= 8:
		// Two possibly-overlapping 4-byte reads cover the whole input.
		h ^= uint64(readUnaligned32(p))
		h ^= uint64(readUnaligned32(add(p, s-4))) << 32
		h = rotl31(h*m1) * m2
	case s <= 16:
		h ^= readUnaligned64(p)
		h = rotl31(h*m1) * m2
		h ^= readUnaligned64(add(p, s-8))
		h = rotl31(h*m1) * m2
	case s <= 32:
		h ^= readUnaligned64(p)
		h = rotl31(h*m1) * m2
		h ^= readUnaligned64(add(p, 8))
		h = rotl31(h*m1) * m2
		h ^= readUnaligned64(add(p, s-16))
		h = rotl31(h*m1) * m2
		h ^= readUnaligned64(add(p, s-8))
		h = rotl31(h*m1) * m2
	default:
		// Four independent accumulator lanes, 8 bytes each per stripe.
		v1 := h
		v2 := uint64(seed * hashkey[1])
		v3 := uint64(seed * hashkey[2])
		v4 := uint64(seed * hashkey[3])
		for s >= 32 {
			v1 ^= readUnaligned64(p)
			v1 = rotl31(v1*m1) * m2
			p = add(p, 8)
			v2 ^= readUnaligned64(p)
			v2 = rotl31(v2*m2) * m3
			p = add(p, 8)
			v3 ^= readUnaligned64(p)
			v3 = rotl31(v3*m3) * m4
			p = add(p, 8)
			v4 ^= readUnaligned64(p)
			v4 = rotl31(v4*m4) * m1
			p = add(p, 8)
			s -= 32
		}
		h = v1 ^ v2 ^ v3 ^ v4
		goto tail // fold the <32-byte remainder through the cases above
	}

	// Final avalanche.
	h ^= h >> 29
	h *= m3
	h ^= h >> 32
	return uintptr(h)
}

// add advances pointer p by x bytes.
func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
	return unsafe.Pointer(uintptr(p) + x)
}

// readUnaligned32 assembles 4 bytes at p as a little-endian uint32,
// independent of host byte order.
func readUnaligned32(p unsafe.Pointer) uint32 {
	q := (*[4]byte)(p)
	return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
}

// rotl31 rotates x left by 31 bits.
func rotl31(x uint64) uint64 {
	return (x << 31) | (x >> (64 - 31))
}

// readUnaligned64 assembles 8 bytes at p as a little-endian uint64,
// independent of host byte order.
func readUnaligned64(p unsafe.Pointer) uint64 {
	q := (*[8]byte)(p)
	return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
}