github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/fuzzy/matcher.go (about) 1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package fuzzy implements a fuzzy matching algorithm. 6 package fuzzy 7 8 import ( 9 "bytes" 10 "fmt" 11 ) 12 13 const ( 14 // MaxInputSize is the maximum size of the input scored against the fuzzy matcher. Longer inputs 15 // will be truncated to this size. 16 MaxInputSize = 127 17 // MaxPatternSize is the maximum size of the pattern used to construct the fuzzy matcher. Longer 18 // inputs are truncated to this size. 19 MaxPatternSize = 63 20 ) 21 22 type scoreVal int 23 24 func (s scoreVal) val() int { 25 return int(s) >> 1 26 } 27 28 func (s scoreVal) prevK() int { 29 return int(s) & 1 30 } 31 32 func score(val int, prevK int /*0 or 1*/) scoreVal { 33 return scoreVal(val<<1 + prevK) 34 } 35 36 // Matcher implements a fuzzy matching algorithm for scoring candidates against a pattern. 37 // The matcher does not support parallel usage. 38 type Matcher struct { 39 pattern string 40 patternLower []byte // lower-case version of the pattern 41 patternShort []byte // first characters of the pattern 42 caseSensitive bool // set if the pattern is mix-cased 43 44 patternRoles []RuneRole // the role of each character in the pattern 45 roles []RuneRole // the role of each character in the tested string 46 47 scores [MaxInputSize + 1][MaxPatternSize + 1][2]scoreVal 48 49 scoreScale float32 50 51 lastCandidateLen int // in bytes 52 lastCandidateMatched bool 53 54 // Reusable buffers to avoid allocating for every candidate. 55 // - inputBuf stores the concatenated input chunks 56 // - lowerBuf stores the last candidate in lower-case 57 // - rolesBuf stores the calculated roles for each rune in the last 58 // candidate. 59 inputBuf [MaxInputSize]byte 60 lowerBuf [MaxInputSize]byte 61 rolesBuf [MaxInputSize]RuneRole 62 } 63 64 func (m *Matcher) bestK(i, j int) int { 65 if m.scores[i][j][0].val() < m.scores[i][j][1].val() { 66 return 1 67 } 68 return 0 69 } 70 71 // NewMatcher returns a new fuzzy matcher for scoring candidates against the provided pattern. 72 func NewMatcher(pattern string) *Matcher { 73 if len(pattern) > MaxPatternSize { 74 pattern = pattern[:MaxPatternSize] 75 } 76 77 m := &Matcher{ 78 pattern: pattern, 79 patternLower: toLower([]byte(pattern), nil), 80 } 81 82 for i, c := range m.patternLower { 83 if pattern[i] != c { 84 m.caseSensitive = true 85 break 86 } 87 } 88 89 if len(pattern) > 3 { 90 m.patternShort = m.patternLower[:3] 91 } else { 92 m.patternShort = m.patternLower 93 } 94 95 m.patternRoles = RuneRoles([]byte(pattern), nil) 96 97 if len(pattern) > 0 { 98 maxCharScore := 4 99 m.scoreScale = 1 / float32(maxCharScore*len(pattern)) 100 } 101 102 return m 103 } 104 105 // Score returns the score returned by matching the candidate to the pattern. 106 // This is not designed for parallel use. Multiple candidates must be scored sequentially. 107 // Returns a score between 0 and 1 (0 - no match, 1 - perfect match). 108 func (m *Matcher) Score(candidate string) float32 { 109 return m.ScoreChunks([]string{candidate}) 110 } 111 112 func (m *Matcher) ScoreChunks(chunks []string) float32 { 113 candidate := fromChunks(chunks, m.inputBuf[:]) 114 if len(candidate) > MaxInputSize { 115 candidate = candidate[:MaxInputSize] 116 } 117 lower := toLower(candidate, m.lowerBuf[:]) 118 m.lastCandidateLen = len(candidate) 119 120 if len(m.pattern) == 0 { 121 // Empty patterns perfectly match candidates. 122 return 1 123 } 124 125 if m.match(candidate, lower) { 126 sc := m.computeScore(candidate, lower) 127 if sc > minScore/2 && !m.poorMatch() { 128 m.lastCandidateMatched = true 129 if len(m.pattern) == len(candidate) { 130 // Perfect match. 131 return 1 132 } 133 134 if sc < 0 { 135 sc = 0 136 } 137 normalizedScore := float32(sc) * m.scoreScale 138 if normalizedScore > 1 { 139 normalizedScore = 1 140 } 141 142 return normalizedScore 143 } 144 } 145 146 m.lastCandidateMatched = false 147 return 0 148 } 149 150 const minScore = -10000 151 152 // MatchedRanges returns matches ranges for the last scored string as a flattened array of 153 // [begin, end) byte offset pairs. 154 func (m *Matcher) MatchedRanges() []int { 155 if len(m.pattern) == 0 || !m.lastCandidateMatched { 156 return nil 157 } 158 i, j := m.lastCandidateLen, len(m.pattern) 159 if m.scores[i][j][0].val() < minScore/2 && m.scores[i][j][1].val() < minScore/2 { 160 return nil 161 } 162 163 var ret []int 164 k := m.bestK(i, j) 165 for i > 0 { 166 take := (k == 1) 167 k = m.scores[i][j][k].prevK() 168 if take { 169 if len(ret) == 0 || ret[len(ret)-1] != i { 170 ret = append(ret, i) 171 ret = append(ret, i-1) 172 } else { 173 ret[len(ret)-1] = i - 1 174 } 175 j-- 176 } 177 i-- 178 } 179 // Reverse slice. 180 for i := 0; i < len(ret)/2; i++ { 181 ret[i], ret[len(ret)-1-i] = ret[len(ret)-1-i], ret[i] 182 } 183 return ret 184 } 185 186 func (m *Matcher) match(candidate []byte, candidateLower []byte) bool { 187 i, j := 0, 0 188 for ; i < len(candidateLower) && j < len(m.patternLower); i++ { 189 if candidateLower[i] == m.patternLower[j] { 190 j++ 191 } 192 } 193 if j != len(m.patternLower) { 194 return false 195 } 196 197 // The input passes the simple test against pattern, so it is time to classify its characters. 198 // Character roles are used below to find the last segment. 199 m.roles = RuneRoles(candidate, m.rolesBuf[:]) 200 201 return true 202 } 203 204 func (m *Matcher) computeScore(candidate []byte, candidateLower []byte) int { 205 pattLen, candLen := len(m.pattern), len(candidate) 206 207 for j := 0; j <= len(m.pattern); j++ { 208 m.scores[0][j][0] = minScore << 1 209 m.scores[0][j][1] = minScore << 1 210 } 211 m.scores[0][0][0] = score(0, 0) // Start with 0. 212 213 segmentsLeft, lastSegStart := 1, 0 214 for i := 0; i < candLen; i++ { 215 if m.roles[i] == RSep { 216 segmentsLeft++ 217 lastSegStart = i + 1 218 } 219 } 220 221 // A per-character bonus for a consecutive match. 222 consecutiveBonus := 2 223 wordIdx := 0 // Word count within segment. 224 for i := 1; i <= candLen; i++ { 225 226 role := m.roles[i-1] 227 isHead := role == RHead 228 229 if isHead { 230 wordIdx++ 231 } else if role == RSep && segmentsLeft > 1 { 232 wordIdx = 0 233 segmentsLeft-- 234 } 235 236 var skipPenalty int 237 if i == 1 || (i-1) == lastSegStart { 238 // Skipping the start of first or last segment. 239 skipPenalty++ 240 } 241 242 for j := 0; j <= pattLen; j++ { 243 // By default, we don't have a match. Fill in the skip data. 244 m.scores[i][j][1] = minScore << 1 245 246 // Compute the skip score. 247 k := 0 248 if m.scores[i-1][j][0].val() < m.scores[i-1][j][1].val() { 249 k = 1 250 } 251 252 skipScore := m.scores[i-1][j][k].val() 253 // Do not penalize missing characters after the last matched segment. 254 if j != pattLen { 255 skipScore -= skipPenalty 256 } 257 m.scores[i][j][0] = score(skipScore, k) 258 259 if j == 0 || candidateLower[i-1] != m.patternLower[j-1] { 260 // Not a match. 261 continue 262 } 263 pRole := m.patternRoles[j-1] 264 265 if role == RTail && pRole == RHead { 266 if j > 1 { 267 // Not a match: a head in the pattern matches a tail character in the candidate. 268 continue 269 } 270 // Special treatment for the first character of the pattern. We allow 271 // matches in the middle of a word if they are long enough, at least 272 // min(3, pattern.length) characters. 273 if !bytes.HasPrefix(candidateLower[i-1:], m.patternShort) { 274 continue 275 } 276 } 277 278 // Compute the char score. 279 var charScore int 280 // Bonus 1: the char is in the candidate's last segment. 281 if segmentsLeft <= 1 { 282 charScore++ 283 } 284 // Bonus 2: Case match or a Head in the pattern aligns with one in the word. 285 // Single-case patterns lack segmentation signals and we assume any character 286 // can be a head of a segment. 287 if candidate[i-1] == m.pattern[j-1] || role == RHead && (!m.caseSensitive || pRole == RHead) { 288 charScore++ 289 } 290 291 // Penalty 1: pattern char is Head, candidate char is Tail. 292 if role == RTail && pRole == RHead { 293 charScore-- 294 } 295 // Penalty 2: first pattern character matched in the middle of a word. 296 if j == 1 && role == RTail { 297 charScore -= 4 298 } 299 300 // Third dimension encodes whether there is a gap between the previous match and the current 301 // one. 302 for k := 0; k < 2; k++ { 303 sc := m.scores[i-1][j-1][k].val() + charScore 304 305 isConsecutive := k == 1 || i-1 == 0 || i-1 == lastSegStart 306 if isConsecutive { 307 // Bonus 3: a consecutive match. First character match also gets a bonus to 308 // ensure prefix final match score normalizes to 1.0. 309 // Logically, this is a part of charScore, but we have to compute it here because it 310 // only applies for consecutive matches (k == 1). 311 sc += consecutiveBonus 312 } 313 if k == 0 { 314 // Penalty 3: Matching inside a segment (and previous char wasn't matched). Penalize for the lack 315 // of alignment. 316 if role == RTail || role == RUCTail { 317 sc -= 3 318 } 319 } 320 321 if sc > m.scores[i][j][1].val() { 322 m.scores[i][j][1] = score(sc, k) 323 } 324 } 325 } 326 } 327 328 result := m.scores[len(candidate)][len(m.pattern)][m.bestK(len(candidate), len(m.pattern))].val() 329 330 return result 331 } 332 333 // ScoreTable returns the score table computed for the provided candidate. Used only for debugging. 334 func (m *Matcher) ScoreTable(candidate string) string { 335 var buf bytes.Buffer 336 337 var line1, line2, separator bytes.Buffer 338 line1.WriteString("\t") 339 line2.WriteString("\t") 340 for j := 0; j < len(m.pattern); j++ { 341 line1.WriteString(fmt.Sprintf("%c\t\t", m.pattern[j])) 342 separator.WriteString("----------------") 343 } 344 345 buf.WriteString(line1.String()) 346 buf.WriteString("\n") 347 buf.WriteString(separator.String()) 348 buf.WriteString("\n") 349 350 for i := 1; i <= len(candidate); i++ { 351 line1.Reset() 352 line2.Reset() 353 354 line1.WriteString(fmt.Sprintf("%c\t", candidate[i-1])) 355 line2.WriteString("\t") 356 357 for j := 1; j <= len(m.pattern); j++ { 358 line1.WriteString(fmt.Sprintf("M%6d(%c)\t", m.scores[i][j][0].val(), dir(m.scores[i][j][0].prevK()))) 359 line2.WriteString(fmt.Sprintf("H%6d(%c)\t", m.scores[i][j][1].val(), dir(m.scores[i][j][1].prevK()))) 360 } 361 buf.WriteString(line1.String()) 362 buf.WriteString("\n") 363 buf.WriteString(line2.String()) 364 buf.WriteString("\n") 365 buf.WriteString(separator.String()) 366 buf.WriteString("\n") 367 } 368 369 return buf.String() 370 } 371 372 func dir(prevK int) rune { 373 if prevK == 0 { 374 return 'M' 375 } 376 return 'H' 377 } 378 379 func (m *Matcher) poorMatch() bool { 380 if len(m.pattern) < 2 { 381 return false 382 } 383 384 i, j := m.lastCandidateLen, len(m.pattern) 385 k := m.bestK(i, j) 386 387 var counter, len int 388 for i > 0 { 389 take := (k == 1) 390 k = m.scores[i][j][k].prevK() 391 if take { 392 len++ 393 if k == 0 && len < 3 && m.roles[i-1] == RTail { 394 // Short match in the middle of a word 395 counter++ 396 if counter > 1 { 397 return true 398 } 399 } 400 j-- 401 } else { 402 len = 0 403 } 404 i-- 405 } 406 return false 407 } 408 409 // BestMatch returns the name most similar to the 410 // pattern, using fuzzy matching, or the empty string. 411 func BestMatch(pattern string, names []string) string { 412 fuzz := NewMatcher(pattern) 413 best := "" 414 highScore := float32(0) // minimum score is 0 (no match) 415 for _, name := range names { 416 // TODO: Improve scoring algorithm. 417 score := fuzz.Score(name) 418 if score > highScore { 419 highScore = score 420 best = name 421 } else if score == 0 { 422 // Order matters in the fuzzy matching algorithm. If we find no match 423 // when matching the target to the identifier, try matching the identifier 424 // to the target. 425 revFuzz := NewMatcher(name) 426 revScore := revFuzz.Score(pattern) 427 if revScore > highScore { 428 highScore = revScore 429 best = name 430 } 431 } 432 } 433 return best 434 }