github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/fuzzy/symbol.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fuzzy
     6  
     7  import (
     8  	"unicode"
     9  )
    10  
    11  // SymbolMatcher implements a fuzzy matching algorithm optimized for Go symbols
    12  // of the form:
    13  //
    14  //	example.com/path/to/package.object.field
    15  //
    16  // Knowing that we are matching symbols like this allows us to make the
    17  // following optimizations:
    18  //   - We can incorporate right-to-left relevance directly into the score
    19  //     calculation.
    20  //   - We can match from right to left, discarding leading bytes if the input is
    21  //     too long.
    22  //   - We just take the right-most match without losing too much precision. This
    23  //     allows us to use an O(n) algorithm.
    24  //   - We can operate directly on chunked strings; in many cases we will
    25  //     be storing the package path and/or package name separately from the
    26  //     symbol or identifiers, so doing this avoids allocating strings.
    27  //   - We can return the index of the right-most match, allowing us to trim
    28  //     irrelevant qualification.
    29  //
    30  // This implementation is experimental, serving as a reference fast algorithm
    31  // to compare to the fuzzy algorithm implemented by Matcher.
    32  type SymbolMatcher struct {
    33  	// Using buffers of length 256 is both a reasonable size for most qualified
    34  	// symbols, and makes it easy to avoid bounds checks by using uint8 indexes.
    35  	pattern     [256]rune
    36  	patternLen  uint8
    37  	inputBuffer [256]rune   // avoid allocating when considering chunks
    38  	roles       [256]uint32 // which roles does a rune play (word start, etc.)
    39  	segments    [256]uint8  // how many segments from the right is each rune
    40  }
    41  
    42  const (
    43  	segmentStart uint32 = 1 << iota
    44  	wordStart
    45  	separator
    46  )
    47  
    48  // NewSymbolMatcher creates a SymbolMatcher that may be used to match the given
    49  // search pattern.
    50  //
    51  // Currently this matcher only accepts case-insensitive fuzzy patterns.
    52  //
    53  // An empty pattern matches no input.
    54  func NewSymbolMatcher(pattern string) *SymbolMatcher {
    55  	m := &SymbolMatcher{}
    56  	for _, p := range pattern {
    57  		m.pattern[m.patternLen] = unicode.ToLower(p)
    58  		m.patternLen++
    59  		if m.patternLen == 255 || int(m.patternLen) == len(pattern) {
    60  			// break at 255 so that we can represent patternLen with a uint8.
    61  			break
    62  		}
    63  	}
    64  	return m
    65  }
    66  
    67  // Match looks for the right-most match of the search pattern within the symbol
    68  // represented by concatenating the given chunks, returning its offset and
    69  // score.
    70  //
    71  // If a match is found, the first return value will hold the absolute byte
    72  // offset within all chunks for the start of the symbol. In other words, the
    73  // index of the match within strings.Join(chunks, ""). If no match is found,
    74  // the first return value will be -1.
    75  //
    76  // The second return value will be the score of the match, which is always
    77  // between 0 and 1, inclusive. A score of 0 indicates no match.
    78  func (m *SymbolMatcher) Match(chunks []string) (int, float64) {
    79  	// Explicit behavior for an empty pattern.
    80  	//
    81  	// As a minor optimization, this also avoids nilness checks later on, since
    82  	// the compiler can prove that m != nil.
    83  	if m.patternLen == 0 {
    84  		return -1, 0
    85  	}
    86  
    87  	// First phase: populate the input buffer with lower-cased runes.
    88  	//
    89  	// We could also check for a forward match here, but since we'd have to write
    90  	// the entire input anyway this has negligible impact on performance.
    91  
    92  	var (
    93  		inputLen  = uint8(0)
    94  		modifiers = wordStart | segmentStart
    95  	)
    96  
    97  input:
    98  	for _, chunk := range chunks {
    99  		for _, r := range chunk {
   100  			if r == '.' || r == '/' {
   101  				modifiers |= separator
   102  			}
   103  			// optimization: avoid calls to unicode.ToLower, which can't be inlined.
   104  			l := r
   105  			if r <= unicode.MaxASCII {
   106  				if 'A' <= r && r <= 'Z' {
   107  					l = r + 'a' - 'A'
   108  				}
   109  			} else {
   110  				l = unicode.ToLower(r)
   111  			}
   112  			if l != r {
   113  				modifiers |= wordStart
   114  			}
   115  			m.inputBuffer[inputLen] = l
   116  			m.roles[inputLen] = modifiers
   117  			inputLen++
   118  			if m.roles[inputLen-1]&separator != 0 {
   119  				modifiers = wordStart | segmentStart
   120  			} else {
   121  				modifiers = 0
   122  			}
   123  			// TODO: we should prefer the right-most input if it overflows, rather
   124  			//       than the left-most as we're doing here.
   125  			if inputLen == 255 {
   126  				break input
   127  			}
   128  		}
   129  	}
   130  
   131  	// Second phase: find the right-most match, and count segments from the
   132  	// right.
   133  
   134  	var (
   135  		pi    = uint8(m.patternLen - 1) // pattern index
   136  		p     = m.pattern[pi]           // pattern rune
   137  		start = -1                      // start offset of match
   138  		rseg  = uint8(0)
   139  	)
   140  	const maxSeg = 3 // maximum number of segments from the right to count, for scoring purposes.
   141  
   142  	for ii := inputLen - 1; ; ii-- {
   143  		r := m.inputBuffer[ii]
   144  		if rseg < maxSeg && m.roles[ii]&separator != 0 {
   145  			rseg++
   146  		}
   147  		m.segments[ii] = rseg
   148  		if p == r {
   149  			if pi == 0 {
   150  				start = int(ii)
   151  				break
   152  			}
   153  			pi--
   154  			p = m.pattern[pi]
   155  		}
   156  		// Don't check ii >= 0 in the loop condition: ii is a uint8.
   157  		if ii == 0 {
   158  			break
   159  		}
   160  	}
   161  
   162  	if start < 0 {
   163  		// no match: skip scoring
   164  		return -1, 0
   165  	}
   166  
   167  	// Third phase: find the shortest match, and compute the score.
   168  
   169  	// Score is the average score for each character.
   170  	//
   171  	// A character score is the multiple of:
   172  	//   1. 1.0 if the character starts a segment, .8 if the character start a
   173  	//      mid-segment word, otherwise 0.6. This carries over to immediately
   174  	//      following characters.
   175  	//   2. For the final character match, the multiplier from (1) is reduced to
   176  	//     .8 if the next character in the input is a mid-segment word, or 0.6 if
   177  	//      the next character in the input is not a word or segment start. This
   178  	//      ensures that we favor whole-word or whole-segment matches over prefix
   179  	//      matches.
   180  	//   3. 1.0 if the character is part of the last segment, otherwise
   181  	//      1.0-.2*<segments from the right>, with a max segment count of 3.
   182  	//
   183  	// This is a very naive algorithm, but it is fast. There's lots of prior art
   184  	// here, and we should leverage it. For example, we could explicitly consider
   185  	// character distance, and exact matches of words or segments.
   186  	//
   187  	// Also note that this might not actually find the highest scoring match, as
   188  	// doing so could require a non-linear algorithm, depending on how the score
   189  	// is calculated.
   190  
   191  	pi = 0
   192  	p = m.pattern[pi]
   193  
   194  	const (
   195  		segStreak  = 1.0
   196  		wordStreak = 0.8
   197  		noStreak   = 0.6
   198  		perSegment = 0.2 // we count at most 3 segments above
   199  	)
   200  
   201  	streakBonus := noStreak
   202  	totScore := 0.0
   203  	for ii := uint8(start); ii < inputLen; ii++ {
   204  		r := m.inputBuffer[ii]
   205  		if r == p {
   206  			pi++
   207  			p = m.pattern[pi]
   208  			// Note: this could be optimized with some bit operations.
   209  			switch {
   210  			case m.roles[ii]&segmentStart != 0 && segStreak > streakBonus:
   211  				streakBonus = segStreak
   212  			case m.roles[ii]&wordStart != 0 && wordStreak > streakBonus:
   213  				streakBonus = wordStreak
   214  			}
   215  			finalChar := pi >= m.patternLen
   216  			// finalCost := 1.0
   217  			if finalChar && streakBonus > noStreak {
   218  				switch {
   219  				case ii == inputLen-1 || m.roles[ii+1]&segmentStart != 0:
   220  					// Full segment: no reduction
   221  				case m.roles[ii+1]&wordStart != 0:
   222  					streakBonus = wordStreak
   223  				default:
   224  					streakBonus = noStreak
   225  				}
   226  			}
   227  			totScore += streakBonus * (1.0 - float64(m.segments[ii])*perSegment)
   228  			if finalChar {
   229  				break
   230  			}
   231  		} else {
   232  			streakBonus = noStreak
   233  		}
   234  	}
   235  
   236  	return start, totScore / float64(m.patternLen)
   237  }