github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/common/strmatcher/matchergroup_mph.go (about)

     1  package strmatcher
     2  
     3  import (
     4  	"math/bits"
     5  	"sort"
     6  	"strings"
     7  	"unsafe"
     8  )
     9  
// PrimeRK is the prime base used in the Rabin-Karp style rolling hash of this
// file: each step computes hash*PrimeRK + nextByte in uint32 arithmetic.
const PrimeRK = 16777619
    12  
    13  // RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash.
    14  func RollingHash(hash uint32, input string) uint32 {
    15  	for i := len(input) - 1; i >= 0; i-- {
    16  		hash = hash*PrimeRK + uint32(input[i])
    17  	}
    18  	return hash
    19  }
    20  
    21  // MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
    22  // as aeshash if aes instruction is available).
    23  // With different seed, each MemHash<seed> performs as distinct hash functions.
    24  func MemHash(seed uint32, input string) uint32 {
    25  	return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep
    26  }
    27  
const (
	// mphMatchTypeCount is the number of matcher types tracked per rule; it sizes
	// the per-type matchers array in mphRuleInfo.
	mphMatchTypeCount = 2 // Full and Domain
)
    31  
// mphRuleInfo is the build-time bookkeeping kept for one distinct pattern
// string while matchers are being added.
type mphRuleInfo struct {
	// rollingHash is the RollingHash of the complete pattern string.
	rollingHash uint32
	// matchers holds the registered values, indexed by matcher type
	// (Build reads the Full and Domain slots).
	matchers [mphMatchTypeCount][]uint32
}
    36  
// MphMatcherGroup is an implementation of MatcherGroup.
// It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
//
// Patterns are collected with AddFullMatcher / AddDomainMatcher, then Build
// constructs the two-level table that Lookup, Match and MatchAny read.
type MphMatcherGroup struct {
	// rules maps RuleIdx -> pattern string; index 0 is reserved for failed lookup.
	rules []string
	// values maps RuleIdx -> registered matcher values for the pattern. Build
	// stores the Full matcher values first, followed by the Domain matcher values.
	values [][]uint32
	// level0 maps (RollingHash & level0Mask) -> seed passed to MemHash.
	level0 []uint32
	// level0Mask is len(level0)-1 and restricts RollingHash to 0 ~ len(level0)-1.
	level0Mask uint32
	// level1 maps (MemHash<seed> & level1Mask) -> stored index into rules.
	level1 []uint32
	// level1Mask is len(level1)-1 and restricts MemHash<seed> to 0 ~ len(level1)-1.
	level1Mask uint32
	// ruleInfos holds per-pattern build data; Build sets it to nil when done.
	ruleInfos *map[string]mphRuleInfo
}
    48  
    49  func NewMphMatcherGroup() *MphMatcherGroup {
    50  	return &MphMatcherGroup{
    51  		rules:      []string{""},
    52  		values:     [][]uint32{nil},
    53  		level0:     nil,
    54  		level0Mask: 0,
    55  		level1:     nil,
    56  		level1Mask: 0,
    57  		ruleInfos:  &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete
    58  	}
    59  }
    60  
    61  // AddFullMatcher implements MatcherGroupForFull.
    62  func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
    63  	pattern := strings.ToLower(matcher.Pattern())
    64  	g.addPattern(0, "", pattern, matcher.Type(), value)
    65  }
    66  
    67  // AddDomainMatcher implements MatcherGroupForDomain.
    68  func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
    69  	pattern := strings.ToLower(matcher.Pattern())
    70  	hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
    71  	g.addPattern(hash, pattern, ".", matcher.Type(), value)     // For partial domain match
    72  }
    73  
    74  func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
    75  	fullPattern := pattern + suffixPattern
    76  	info, found := (*g.ruleInfos)[fullPattern]
    77  	if !found {
    78  		info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)}
    79  		g.rules = append(g.rules, fullPattern)
    80  		g.values = append(g.values, nil)
    81  	}
    82  	info.matchers[matcherType] = append(info.matchers[matcherType], value)
    83  	(*g.ruleInfos)[fullPattern] = info
    84  	return info.rollingHash
    85  }
    86  
    87  // Build builds a minimal perfect hash table for insert rules.
    88  // Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
    89  func (g *MphMatcherGroup) Build() error {
    90  	ruleCount := len(*g.ruleInfos)
    91  	g.level0 = make([]uint32, nextPow2(ruleCount/4))
    92  	g.level0Mask = uint32(len(g.level0) - 1)
    93  	g.level1 = make([]uint32, nextPow2(ruleCount))
    94  	g.level1Mask = uint32(len(g.level1) - 1)
    95  
    96  	// Create buckets based on all rule's rolling hash
    97  	buckets := make([][]uint32, len(g.level0))
    98  	for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
    99  		ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
   100  		bucketIdx := ruleInfo.rollingHash & g.level0Mask
   101  		buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
   102  		g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
   103  	}
   104  	g.ruleInfos = nil // Set ruleInfos nil to release memory
   105  
   106  	// Sort buckets in descending order with respect to each bucket's size
   107  	bucketIdxs := make([]int, len(buckets))
   108  	for bucketIdx := range buckets {
   109  		bucketIdxs[bucketIdx] = bucketIdx
   110  	}
   111  	sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
   112  
   113  	// Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
   114  	occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
   115  	hashedBucket := make([]uint32, 0, 4)    // Second-level hashes for each rule in a specific bucket
   116  	for _, bucketIdx := range bucketIdxs {
   117  		bucket := buckets[bucketIdx]
   118  		hashedBucket = hashedBucket[:0]
   119  		seed := uint32(0)
   120  		for len(hashedBucket) != len(bucket) {
   121  			for _, ruleIdx := range bucket {
   122  				memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
   123  				if occupied[memHash] { // Collision occurred with this seed
   124  					for _, hash := range hashedBucket { // Revert all values in this hashed bucket
   125  						occupied[hash] = false
   126  						g.level1[hash] = 0
   127  					}
   128  					hashedBucket = hashedBucket[:0]
   129  					seed++ // Try next seed
   130  					break
   131  				}
   132  				occupied[memHash] = true
   133  				g.level1[memHash] = ruleIdx // The final value in the hash table
   134  				hashedBucket = append(hashedBucket, memHash)
   135  			}
   136  		}
   137  		g.level0[bucketIdx] = seed // Displacement value for this bucket
   138  	}
   139  	return nil
   140  }
   141  
   142  // Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found.
   143  func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
   144  	i0 := rollingHash & g.level0Mask
   145  	seed := g.level0[i0]
   146  	i1 := MemHash(seed, input) & g.level1Mask
   147  	if n := g.level1[i1]; g.rules[n] == input {
   148  		return n
   149  	}
   150  	return 0
   151  }
   152  
   153  // Match implements MatcherGroup.Match.
   154  func (g *MphMatcherGroup) Match(input string) []uint32 {
   155  	matches := make([][]uint32, 0, 5)
   156  	hash := uint32(0)
   157  	for i := len(input) - 1; i >= 0; i-- {
   158  		hash = hash*PrimeRK + uint32(input[i])
   159  		if input[i] == '.' {
   160  			if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 {
   161  				matches = append(matches, g.values[mphIdx])
   162  			}
   163  		}
   164  	}
   165  	if mphIdx := g.Lookup(hash, input); mphIdx != 0 {
   166  		matches = append(matches, g.values[mphIdx])
   167  	}
   168  	return CompositeMatchesReverse(matches)
   169  }
   170  
   171  // MatchAny implements MatcherGroup.MatchAny.
   172  func (g *MphMatcherGroup) MatchAny(input string) bool {
   173  	hash := uint32(0)
   174  	for i := len(input) - 1; i >= 0; i-- {
   175  		hash = hash*PrimeRK + uint32(input[i])
   176  		if input[i] == '.' {
   177  			if g.Lookup(hash, input[i:]) != 0 {
   178  				return true
   179  			}
   180  		}
   181  	}
   182  	return g.Lookup(hash, input) != 0
   183  }
   184  
   185  func nextPow2(v int) int {
   186  	if v <= 1 {
   187  		return 1
   188  	}
   189  	const MaxUInt = ^uint(0)
   190  	n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1
   191  	return int(n)
   192  }
   193  
// strhash is a body-less declaration bound to runtime.strhash through
// go:linkname. MemHash calls it with a pointer to a string variable as p and
// the seed as h.
//
//go:noescape
//go:linkname strhash runtime.strhash
func strhash(p unsafe.Pointer, h uintptr) uintptr