github.com/AlexanderZh/ahocorasick@v0.1.8/ahocorasick.go (about) 1 // Package ahocorasick implements the Aho-Corasick string matching algorithm for 2 // efficiently finding all instances of multiple patterns in a text. 3 package ahocorasick 4 5 import ( 6 "bytes" 7 "encoding/binary" 8 "fmt" 9 "io" 10 "sort" 11 ) 12 13 const ( 14 // leaf represents a leaf on the trie 15 // This must be <255 since the offsets used are in [0,255] 16 // This should only appear in the Base array since the Check array uses 17 // negative values to represent free states. 18 leaf = -1867 19 ) 20 21 type SWord struct { 22 Len uint64 23 Key uint64 24 } 25 26 // Matcher is the pattern matching state machine. 27 type Matcher struct { 28 base []int // base array in the double array trie 29 check []int // check array in the double array trie 30 fail []int // fail function 31 output [][]SWord // output function: originally [state][wordlen], replaced to tuple of {wordlen,workey} 32 } 33 34 // added function for byte serialization of compiled matcher 35 func (m *Matcher) Serialize() []byte { 36 var lenBase, lenCheck, lenFail, lenOutput uint64 37 38 lenBase = uint64(len(m.base)) 39 lenCheck = uint64(len(m.check)) 40 lenFail = uint64(len(m.fail)) 41 lenOutput = uint64(len(m.output)) 42 43 lenOutputEach := make([]uint64, lenOutput) 44 45 for i, v := range m.output { 46 lenOutputEach[i] = uint64(len(v)) 47 } 48 49 buf := new(bytes.Buffer) 50 err := binary.Write(buf, binary.LittleEndian, lenBase) 51 if err != nil { 52 fmt.Println("binary.Write failed for lenBase:", err) 53 } 54 err = binary.Write(buf, binary.LittleEndian, lenCheck) 55 if err != nil { 56 fmt.Println("binary.Write failed for lenCheck:", err) 57 } 58 err = binary.Write(buf, binary.LittleEndian, lenFail) 59 if err != nil { 60 fmt.Println("binary.Write failed for lenFail:", err) 61 } 62 err = binary.Write(buf, binary.LittleEndian, lenOutput) //2d array 63 if err != nil { 64 fmt.Println("binary.Write failed lenOutput:", err) 65 } 66 67 for i, v := range lenOutputEach { 68 err = binary.Write(buf, binary.LittleEndian, uint64(v)) 69 if err != nil { 70 fmt.Printf("binary.Write failed for lenOutputEach: %s at position %d", err, i) 71 } 72 } 73 74 for i, v := range m.base { 75 err = binary.Write(buf, binary.LittleEndian, uint64(v)) 76 if err != nil { 77 fmt.Printf("binary.Write failed: %s at base, position %d", err, i) 78 } 79 } 80 for i, v := range m.check { 81 err = binary.Write(buf, binary.LittleEndian, uint64(v)) 82 if err != nil { 83 fmt.Printf("binary.Write failed: %s at check, position %d", err, i) 84 } 85 } 86 for i, v := range m.fail { 87 err = binary.Write(buf, binary.LittleEndian, uint64(v)) 88 if err != nil { 89 fmt.Printf("binary.Write failed: %s at fail, position %d", err, i) 90 } 91 } 92 for i, v := range m.output { 93 for j, u := range v { 94 err = binary.Write(buf, binary.LittleEndian, u) 95 if err != nil { 96 fmt.Printf("binary.Write failed: %s at output, position %d, %d", err, i, j) 97 } 98 } 99 } 100 return (buf.Bytes()) 101 } 102 103 type DeserializeError struct{} 104 105 func (m *DeserializeError) Error() string { 106 return "Finite state machine is corrupted" 107 } 108 109 func Deserialize(data []byte) (m *Matcher, err error) { 110 m = new(Matcher) 111 112 totalLength := len(data) 113 114 if totalLength < 32 || totalLength%8 != 0 { 115 err = &DeserializeError{} 116 return 117 } 118 //reader := bytes.NewReader(data) 119 reader := bytes.NewReader(data) 120 121 var lenBase, lenCheck, lenFail, lenOutput uint64 122 123 err = binary.Read(reader, binary.LittleEndian, &lenBase) 124 if err != nil { 125 return 126 } 127 err = binary.Read(reader, binary.LittleEndian, &lenCheck) 128 if err != nil { 129 return 130 } 131 err = binary.Read(reader, binary.LittleEndian, &lenFail) 132 if err != nil { 133 return 134 } 135 err = binary.Read(reader, binary.LittleEndian, &lenOutput) 136 if err != nil { 137 return 138 } 139 140 lenOutputEach := make([]uint64, lenOutput) 141 142 if totalLength < 8*(4+int(lenOutput)) { 143 err = &DeserializeError{} 144 return 145 } 146 147 calculatedLength := 8 * (4 + int(lenOutput) + int(lenBase) + int(lenCheck) + int(lenFail)) 148 149 for i := 0; i < int(lenOutput); i++ { 150 err = binary.Read(reader, binary.LittleEndian, &(lenOutputEach[i])) 151 if err != nil { 152 return 153 } 154 calculatedLength += 16 * int(lenOutputEach[i]) 155 } 156 157 if calculatedLength != totalLength { 158 err = &DeserializeError{} 159 return 160 } 161 162 err = readToSlice(reader, lenBase, &m.base) 163 if err != nil { 164 return 165 } 166 err = readToSlice(reader, lenCheck, &m.check) 167 if err != nil { 168 return 169 } 170 err = readToSlice(reader, lenFail, &m.fail) 171 if err != nil { 172 return 173 } 174 m.output = make([][]SWord, lenOutput) 175 for i, v := range lenOutputEach { 176 err = readToSliceSWord(reader, v, &m.output[i]) 177 if err != nil { 178 return 179 } 180 } 181 182 return 183 } 184 185 func readToSlice(reader *bytes.Reader, len uint64, array *[]int) error { 186 *array = make([]int, len) 187 var item uint64 188 for i := 0; i < int(len); i++ { 189 err := binary.Read(reader, binary.LittleEndian, &item) 190 if err != nil { 191 return err 192 } 193 (*array)[i] = int(item) 194 } 195 return nil 196 } 197 198 func readToSliceSWord(reader *bytes.Reader, len uint64, array *[]SWord) error { 199 *array = make([]SWord, len) 200 var item uint64 201 var err error 202 for i := 0; i < int(len); i++ { 203 err = binary.Read(reader, binary.LittleEndian, &item) 204 if err != nil { 205 return err 206 } 207 (*array)[i].Len = item 208 err = binary.Read(reader, binary.LittleEndian, &item) 209 if err != nil { 210 return err 211 } 212 (*array)[i].Key = item 213 } 214 return nil 215 } 216 217 func (m *Matcher) String() string { 218 return fmt.Sprintf(` 219 Base: %v 220 Check: %v 221 Fail: %v 222 Output: %v 223 `, m.base, m.check, m.fail, m.output) 224 } 225 226 type byteSliceSlice [][]byte 227 228 func (bss byteSliceSlice) Len() int { return len(bss) } 229 func (bss byteSliceSlice) Less(i, j int) bool { return bytes.Compare(bss[i], bss[j]) < 1 } 230 func (bss byteSliceSlice) Swap(i, j int) { bss[i], bss[j] = bss[j], bss[i] } 231 232 func compile(words [][]byte) *Matcher { 233 m := new(Matcher) 234 m.base = make([]int, 2048)[:1] 235 m.check = make([]int, 2048)[:1] 236 m.fail = make([]int, 2048)[:1] 237 m.output = make([][]SWord, 2048)[:1] 238 239 sort.Sort(byteSliceSlice(words)) 240 241 // Represents a node in the implicit trie of words 242 type trienode struct { 243 state int 244 depth int 245 start int 246 end int 247 } 248 queue := make([]trienode, 2048)[:1] 249 queue[0] = trienode{0, 0, 0, len(words)} 250 251 for len(queue) > 0 { 252 node := queue[0] 253 queue = queue[1:] 254 255 if node.end <= node.start { 256 m.base[node.state] = leaf 257 continue 258 } 259 260 var edges []byte 261 for i := node.start; i < node.end; i++ { 262 if len(edges) == 0 || edges[len(edges)-1] != words[i][node.depth] { 263 edges = append(edges, words[i][node.depth]) 264 } 265 } 266 267 // Calculate a suitable Base value where each edge will fit into the 268 // double array trie 269 base := m.findBase(edges) 270 m.base[node.state] = base 271 272 i := node.start 273 for _, edge := range edges { 274 offset := int(edge) 275 newState := base + offset 276 277 m.occupyState(newState, node.state) 278 279 // level 0 and level 1 should fail to state 0 280 if node.depth > 0 { 281 m.setFailState(newState, node.state, offset) 282 } 283 m.unionFailOutput(newState, m.fail[newState]) 284 285 // Add the child nodes to the queue to continue down the BFS 286 newnode := trienode{newState, node.depth + 1, i, i} 287 for { 288 if newnode.depth >= len(words[i]) { 289 m.output[newState] = append(m.output[newState], SWord{uint64(len(words[i])), uint64(i)}) 290 newnode.start++ 291 } 292 newnode.end++ 293 294 i++ 295 if i >= node.end || words[i][node.depth] != edge { 296 break 297 } 298 } 299 queue = append(queue, newnode) 300 } 301 } 302 303 return m 304 } 305 306 // CompileByteSlices compiles a Matcher from a slice of byte slices. This Matcher can be 307 // used to find occurrences of each pattern in a text. 308 func CompileByteSlices(words [][]byte) *Matcher { 309 return compile(words) 310 } 311 312 // CompileStrings compiles a Matcher from a slice of strings. This Matcher can 313 // be used to find occurrences of each pattern in a text. 314 func CompileStrings(words []string) *Matcher { 315 var wordByteSlices [][]byte 316 for _, word := range words { 317 wordByteSlices = append(wordByteSlices, []byte(word)) 318 } 319 return compile(wordByteSlices) 320 } 321 322 // occupyState will correctly occupy state so it maintains the 323 // index=check[base[index]+offset] identity. It will also update the 324 // bidirectional link of free states correctly. 325 // Note: This MUST be used instead of simply modifying the check array directly 326 // which is break the bidirectional link of free states. 327 func (m *Matcher) occupyState(state, parentState int) { 328 firstFreeState := m.firstFreeState() 329 lastFreeState := m.lastFreeState() 330 if firstFreeState == lastFreeState { 331 m.check[0] = 0 332 } else { 333 switch state { 334 case firstFreeState: 335 next := -1 * m.check[state] 336 m.check[0] = -1 * next 337 m.base[next] = m.base[state] 338 case lastFreeState: 339 prev := -1 * m.base[state] 340 m.base[firstFreeState] = -1 * prev 341 m.check[prev] = -1 342 default: 343 next := -1 * m.check[state] 344 prev := -1 * m.base[state] 345 m.check[prev] = -1 * next 346 m.base[next] = -1 * prev 347 } 348 } 349 m.check[state] = parentState 350 m.base[state] = leaf 351 } 352 353 // setFailState sets the output of the fail function for input state. It will 354 // traverse up the fail states of it's ancestors until it reaches a fail state 355 // with a transition for offset. 356 func (m *Matcher) setFailState(state, parentState, offset int) { 357 failState := m.fail[parentState] 358 for { 359 if m.hasEdge(failState, offset) { 360 m.fail[state] = m.base[failState] + offset 361 break 362 } 363 if failState == 0 { 364 break 365 } 366 failState = m.fail[failState] 367 } 368 } 369 370 // unionFailOutput unions the output function for failState with the output 371 // function for state and sets the result as the output function for state. 372 // This allows us to match substrings, commenting out this body would match 373 // every word that is not a substring. 374 func (m *Matcher) unionFailOutput(state, failState int) { 375 m.output[state] = append([]SWord{}, m.output[failState]...) 376 } 377 378 // findBase finds a base value which has free states in the positions that 379 // correspond to each edge transition in edges. If this does not exist, then 380 // base and check (and the fail array for consistency) will be extended just 381 // enough to fit each transition. 382 // The extension will maintain the bidirectional link of free states. 383 func (m *Matcher) findBase(edges []byte) int { 384 if len(edges) == 0 { 385 return leaf 386 } 387 388 min := int(edges[0]) 389 max := int(edges[len(edges)-1]) 390 width := max - min 391 freeState := m.firstFreeState() 392 for freeState != -1 { 393 valid := true 394 for _, e := range edges[1:] { 395 state := freeState + int(e) - min 396 if state >= len(m.check) { 397 break 398 } else if m.check[state] >= 0 { 399 valid = false 400 break 401 } 402 } 403 404 if valid { 405 if freeState+width >= len(m.check) { 406 m.increaseSize(width - len(m.check) + freeState + 1) 407 } 408 return freeState - min 409 } 410 411 freeState = m.nextFreeState(freeState) 412 } 413 freeState = len(m.check) 414 m.increaseSize(width + 1) 415 return freeState - min 416 } 417 418 // increaseSize increases the size of base, check, and fail to ensure they 419 // remain the same size. 420 // It also sets the default value for these new unoccupied states which form 421 // bidirectional links to allow fast access to empty states. These 422 // bidirectional links only pertain to base and check. 423 // 424 // Example: 425 // m: 426 // 427 // base: [ 5 0 0 ] 428 // check: [ 0 0 0 ] 429 // 430 // increaseSize(3): 431 // 432 // base: [ 5 0 0 -5 -3 -4 ] 433 // check: [ -3 0 0 -4 -5 -1 ] 434 // 435 // increaseSize(3): 436 // 437 // base: [ 5 0 0 -8 -3 -4 -5 -6 -7] 438 // check: [ -3 0 0 -4 -5 -6 -7 -8 -1] 439 // 440 // m: 441 // 442 // base: [ 5 0 0 ] 443 // check: [ 0 0 0 ] 444 // 445 // increaseSize(1): 446 // 447 // base: [ 5 0 0 -3 ] 448 // check: [ -3 0 0 -1 ] 449 // 450 // increaseSize(1): 451 // 452 // base: [ 5 0 0 -4 -3 ] 453 // check: [ -3 0 0 -4 -1 ] 454 // 455 // increaseSize(1): 456 // 457 // base: [ 5 0 0 -5 -3 -4 ] 458 // check: [ -3 0 0 -4 -5 -1 ] 459 func (m *Matcher) increaseSize(dsize int) { 460 if dsize == 0 { 461 return 462 } 463 464 m.base = append(m.base, make([]int, dsize)...) 465 m.check = append(m.check, make([]int, dsize)...) 466 m.fail = append(m.fail, make([]int, dsize)...) 467 m.output = append(m.output, make([][]SWord, dsize)...) 468 469 lastFreeState := m.lastFreeState() 470 firstFreeState := m.firstFreeState() 471 for i := len(m.check) - dsize; i < len(m.check); i++ { 472 if lastFreeState == -1 { 473 m.check[0] = -1 * i 474 m.base[i] = -1 * i 475 m.check[i] = -1 476 firstFreeState = i 477 lastFreeState = i 478 } else { 479 m.base[i] = -1 * lastFreeState 480 m.check[i] = -1 481 m.base[firstFreeState] = -1 * i 482 m.check[lastFreeState] = -1 * i 483 lastFreeState = i 484 } 485 } 486 } 487 488 // nextFreeState uses the nature of the bidirectional link to determine the 489 // closest free state at a larger index. Since the check array holds the 490 // negative index of the next free state, except for the last free state which 491 // has a value of -1, negating this value is the next free state. 492 func (m *Matcher) nextFreeState(curFreeState int) int { 493 nextState := -1 * m.check[curFreeState] 494 495 // state 1 can never be a free state. 496 if nextState == 1 { 497 return -1 498 } 499 500 return nextState 501 } 502 503 // firstFreeState uses the first value in the check array which points to the 504 // first free state. A value of 0 means there are no free states and -1 is 505 // returned. 506 func (m *Matcher) firstFreeState() int { 507 state := m.check[0] 508 if state != 0 { 509 return -1 * state 510 } 511 return -1 512 } 513 514 // lastFreeState uses the base value of the first free state which points the 515 // last free state. 516 func (m *Matcher) lastFreeState() int { 517 firstFree := m.firstFreeState() 518 if firstFree != -1 { 519 return -1 * m.base[firstFree] 520 } 521 return -1 522 } 523 524 // hasEdge determines if the fromState has a transition for offset. 525 func (m *Matcher) hasEdge(fromState, offset int) bool { 526 toState := m.base[fromState] + offset 527 return toState > 0 && toState < len(m.check) && m.check[toState] == fromState 528 } 529 530 // Match represents a matched pattern in the text 531 type Match struct { 532 Word []byte // the matched pattern 533 Index int // the start index of the match 534 } 535 536 type Matches interface { 537 Append(key int, position int) 538 Count() int 539 } 540 541 func (m *Matcher) findAll(text []byte) []*Match { 542 var matches []*Match 543 state := 0 544 for i, b := range text { 545 offset := int(b) 546 for state != 0 && !m.hasEdge(state, offset) { 547 state = m.fail[state] 548 } 549 550 if m.hasEdge(state, offset) { 551 state = m.base[state] + offset 552 } 553 for _, item := range m.output[state] { 554 matches = append(matches, &Match{text[i-int(item.Len)+1 : i+1], i - int(item.Len) + 1}) 555 } 556 } 557 return matches 558 } 559 560 func (m *Matcher) findAllReader(reader io.Reader, matches Matches) { 561 state := 0 562 buf := make([]byte, 1) 563 n, err := reader.Read(buf) 564 b := int(buf[0]) 565 i := 1 566 for err == nil && n == 1 { 567 offset := b 568 for state != 0 && !m.hasEdge(state, offset) { 569 state = m.fail[state] 570 } 571 572 if m.hasEdge(state, offset) { 573 state = m.base[state] + offset 574 } 575 for _, item := range m.output[state] { 576 matches.Append(i, int(item.Key)) 577 } 578 n, err = reader.Read(buf) 579 b = int(buf[0]) 580 i++ 581 } 582 } 583 584 // FindAllByteSlice finds all instances of the patterns in the text. 585 func (m *Matcher) FindAllByteSlice(text []byte) (matches []*Match) { 586 return m.findAll(text) 587 } 588 589 func (m *Matcher) FindAllByteReader(reader io.Reader, matches Matches) { 590 m.findAllReader(reader, matches) 591 } 592 593 // FindAllString finds all instances of the patterns in the text. 594 func (m *Matcher) FindAllString(text string) []*Match { 595 return m.FindAllByteSlice([]byte(text)) 596 }