github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/common/strmatcher/matchergroup_mph.go (about) 1 package strmatcher 2 3 import ( 4 "math/bits" 5 "sort" 6 "strings" 7 "unsafe" 8 ) 9 10 // PrimeRK is the prime base used in Rabin-Karp algorithm. 11 const PrimeRK = 16777619 12 13 // RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash. 14 func RollingHash(hash uint32, input string) uint32 { 15 for i := len(input) - 1; i >= 0; i-- { 16 hash = hash*PrimeRK + uint32(input[i]) 17 } 18 return hash 19 } 20 21 // MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves 22 // as aeshash if aes instruction is available). 23 // With different seed, each MemHash<seed> performs as distinct hash functions. 24 func MemHash(seed uint32, input string) uint32 { 25 return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep 26 } 27 28 const ( 29 mphMatchTypeCount = 2 // Full and Domain 30 ) 31 32 type mphRuleInfo struct { 33 rollingHash uint32 34 matchers [mphMatchTypeCount][]uint32 35 } 36 37 // MphMatcherGroup is an implementation of MatcherGroup. 38 // It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher. 39 type MphMatcherGroup struct { 40 rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup 41 values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence) 42 level0 []uint32 // RollingHash & Mask -> seed for Memhash 43 level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0) 44 level1 []uint32 // Memhash<seed> & Mask -> stored index for rules 45 level1Mask uint32 // Mask for restricting Memhash<seed> to 0 ~ len(level1) 46 ruleInfos *map[string]mphRuleInfo 47 } 48 49 func NewMphMatcherGroup() *MphMatcherGroup { 50 return &MphMatcherGroup{ 51 rules: []string{""}, 52 values: [][]uint32{nil}, 53 level0: nil, 54 level0Mask: 0, 55 level1: nil, 56 level1Mask: 0, 57 ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete 58 } 59 } 60 61 // AddFullMatcher implements MatcherGroupForFull. 62 func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { 63 pattern := strings.ToLower(matcher.Pattern()) 64 g.addPattern(0, "", pattern, matcher.Type(), value) 65 } 66 67 // AddDomainMatcher implements MatcherGroupForDomain. 68 func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { 69 pattern := strings.ToLower(matcher.Pattern()) 70 hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match 71 g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match 72 } 73 74 func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 { 75 fullPattern := pattern + suffixPattern 76 info, found := (*g.ruleInfos)[fullPattern] 77 if !found { 78 info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)} 79 g.rules = append(g.rules, fullPattern) 80 g.values = append(g.values, nil) 81 } 82 info.matchers[matcherType] = append(info.matchers[matcherType], value) 83 (*g.ruleInfos)[fullPattern] = info 84 return info.rollingHash 85 } 86 87 // Build builds a minimal perfect hash table for insert rules. 88 // Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf 89 func (g *MphMatcherGroup) Build() error { 90 ruleCount := len(*g.ruleInfos) 91 g.level0 = make([]uint32, nextPow2(ruleCount/4)) 92 g.level0Mask = uint32(len(g.level0) - 1) 93 g.level1 = make([]uint32, nextPow2(ruleCount)) 94 g.level1Mask = uint32(len(g.level1) - 1) 95 96 // Create buckets based on all rule's rolling hash 97 buckets := make([][]uint32, len(g.level0)) 98 for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup) 99 ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]] 100 bucketIdx := ruleInfo.rollingHash & g.level0Mask 101 buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx)) 102 g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic 103 } 104 g.ruleInfos = nil // Set ruleInfos nil to release memory 105 106 // Sort buckets in descending order with respect to each bucket's size 107 bucketIdxs := make([]int, len(buckets)) 108 for bucketIdx := range buckets { 109 bucketIdxs[bucketIdx] = bucketIdx 110 } 111 sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) }) 112 113 // Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table 114 occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used 115 hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket 116 for _, bucketIdx := range bucketIdxs { 117 bucket := buckets[bucketIdx] 118 hashedBucket = hashedBucket[:0] 119 seed := uint32(0) 120 for len(hashedBucket) != len(bucket) { 121 for _, ruleIdx := range bucket { 122 memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask 123 if occupied[memHash] { // Collision occurred with this seed 124 for _, hash := range hashedBucket { // Revert all values in this hashed bucket 125 occupied[hash] = false 126 g.level1[hash] = 0 127 } 128 hashedBucket = hashedBucket[:0] 129 seed++ // Try next seed 130 break 131 } 132 occupied[memHash] = true 133 g.level1[memHash] = ruleIdx // The final value in the hash table 134 hashedBucket = append(hashedBucket, memHash) 135 } 136 } 137 g.level0[bucketIdx] = seed // Displacement value for this bucket 138 } 139 return nil 140 } 141 142 // Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found. 143 func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 { 144 i0 := rollingHash & g.level0Mask 145 seed := g.level0[i0] 146 i1 := MemHash(seed, input) & g.level1Mask 147 if n := g.level1[i1]; g.rules[n] == input { 148 return n 149 } 150 return 0 151 } 152 153 // Match implements MatcherGroup.Match. 154 func (g *MphMatcherGroup) Match(input string) []uint32 { 155 matches := make([][]uint32, 0, 5) 156 hash := uint32(0) 157 for i := len(input) - 1; i >= 0; i-- { 158 hash = hash*PrimeRK + uint32(input[i]) 159 if input[i] == '.' { 160 if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 { 161 matches = append(matches, g.values[mphIdx]) 162 } 163 } 164 } 165 if mphIdx := g.Lookup(hash, input); mphIdx != 0 { 166 matches = append(matches, g.values[mphIdx]) 167 } 168 return CompositeMatchesReverse(matches) 169 } 170 171 // MatchAny implements MatcherGroup.MatchAny. 172 func (g *MphMatcherGroup) MatchAny(input string) bool { 173 hash := uint32(0) 174 for i := len(input) - 1; i >= 0; i-- { 175 hash = hash*PrimeRK + uint32(input[i]) 176 if input[i] == '.' { 177 if g.Lookup(hash, input[i:]) != 0 { 178 return true 179 } 180 } 181 } 182 return g.Lookup(hash, input) != 0 183 } 184 185 func nextPow2(v int) int { 186 if v <= 1 { 187 return 1 188 } 189 const MaxUInt = ^uint(0) 190 n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 191 return int(n) 192 } 193 194 //go:noescape 195 //go:linkname strhash runtime.strhash 196 func strhash(p unsafe.Pointer, h uintptr) uintptr