github.com/AlexanderZh/ahocorasick@v0.1.8/ahocorasick.go (about)

     1  // Package ahocorasick implements the Aho-Corasick string matching algorithm for
     2  // efficiently finding all instances of multiple patterns in a text.
     3  package ahocorasick
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"sort"
    11  )
    12  
    13  const (
    14  	// leaf represents a leaf on the trie
    15  	// This must be <255 since the offsets used are in [0,255]
    16  	// This should only appear in the Base array since the Check array uses
    17  	// negative values to represent free states.
    18  	leaf = -1867
    19  )
    20  
    21  type SWord struct {
    22  	Len uint64
    23  	Key uint64
    24  }
    25  
    26  // Matcher is the pattern matching state machine.
    27  type Matcher struct {
    28  	base   []int     // base array in the double array trie
    29  	check  []int     // check array in the double array trie
    30  	fail   []int     // fail function
    31  	output [][]SWord // output function: originally [state][wordlen], replaced to tuple of {wordlen,workey}
    32  }
    33  
    34  // added function for byte serialization of compiled matcher
    35  func (m *Matcher) Serialize() []byte {
    36  	var lenBase, lenCheck, lenFail, lenOutput uint64
    37  
    38  	lenBase = uint64(len(m.base))
    39  	lenCheck = uint64(len(m.check))
    40  	lenFail = uint64(len(m.fail))
    41  	lenOutput = uint64(len(m.output))
    42  
    43  	lenOutputEach := make([]uint64, lenOutput)
    44  
    45  	for i, v := range m.output {
    46  		lenOutputEach[i] = uint64(len(v))
    47  	}
    48  
    49  	buf := new(bytes.Buffer)
    50  	err := binary.Write(buf, binary.LittleEndian, lenBase)
    51  	if err != nil {
    52  		fmt.Println("binary.Write failed for lenBase:", err)
    53  	}
    54  	err = binary.Write(buf, binary.LittleEndian, lenCheck)
    55  	if err != nil {
    56  		fmt.Println("binary.Write failed for lenCheck:", err)
    57  	}
    58  	err = binary.Write(buf, binary.LittleEndian, lenFail)
    59  	if err != nil {
    60  		fmt.Println("binary.Write failed for lenFail:", err)
    61  	}
    62  	err = binary.Write(buf, binary.LittleEndian, lenOutput) //2d array
    63  	if err != nil {
    64  		fmt.Println("binary.Write failed lenOutput:", err)
    65  	}
    66  
    67  	for i, v := range lenOutputEach {
    68  		err = binary.Write(buf, binary.LittleEndian, uint64(v))
    69  		if err != nil {
    70  			fmt.Printf("binary.Write failed for lenOutputEach: %s at position %d", err, i)
    71  		}
    72  	}
    73  
    74  	for i, v := range m.base {
    75  		err = binary.Write(buf, binary.LittleEndian, uint64(v))
    76  		if err != nil {
    77  			fmt.Printf("binary.Write failed: %s at base, position %d", err, i)
    78  		}
    79  	}
    80  	for i, v := range m.check {
    81  		err = binary.Write(buf, binary.LittleEndian, uint64(v))
    82  		if err != nil {
    83  			fmt.Printf("binary.Write failed: %s at check, position %d", err, i)
    84  		}
    85  	}
    86  	for i, v := range m.fail {
    87  		err = binary.Write(buf, binary.LittleEndian, uint64(v))
    88  		if err != nil {
    89  			fmt.Printf("binary.Write failed: %s at fail, position %d", err, i)
    90  		}
    91  	}
    92  	for i, v := range m.output {
    93  		for j, u := range v {
    94  			err = binary.Write(buf, binary.LittleEndian, u)
    95  			if err != nil {
    96  				fmt.Printf("binary.Write failed: %s at output, position %d, %d", err, i, j)
    97  			}
    98  		}
    99  	}
   100  	return (buf.Bytes())
   101  }
   102  
   103  type DeserializeError struct{}
   104  
   105  func (m *DeserializeError) Error() string {
   106  	return "Finite state machine is corrupted"
   107  }
   108  
   109  func Deserialize(data []byte) (m *Matcher, err error) {
   110  	m = new(Matcher)
   111  
   112  	totalLength := len(data)
   113  
   114  	if totalLength < 32 || totalLength%8 != 0 {
   115  		err = &DeserializeError{}
   116  		return
   117  	}
   118  	//reader := bytes.NewReader(data)
   119  	reader := bytes.NewReader(data)
   120  
   121  	var lenBase, lenCheck, lenFail, lenOutput uint64
   122  
   123  	err = binary.Read(reader, binary.LittleEndian, &lenBase)
   124  	if err != nil {
   125  		return
   126  	}
   127  	err = binary.Read(reader, binary.LittleEndian, &lenCheck)
   128  	if err != nil {
   129  		return
   130  	}
   131  	err = binary.Read(reader, binary.LittleEndian, &lenFail)
   132  	if err != nil {
   133  		return
   134  	}
   135  	err = binary.Read(reader, binary.LittleEndian, &lenOutput)
   136  	if err != nil {
   137  		return
   138  	}
   139  
   140  	lenOutputEach := make([]uint64, lenOutput)
   141  
   142  	if totalLength < 8*(4+int(lenOutput)) {
   143  		err = &DeserializeError{}
   144  		return
   145  	}
   146  
   147  	calculatedLength := 8 * (4 + int(lenOutput) + int(lenBase) + int(lenCheck) + int(lenFail))
   148  
   149  	for i := 0; i < int(lenOutput); i++ {
   150  		err = binary.Read(reader, binary.LittleEndian, &(lenOutputEach[i]))
   151  		if err != nil {
   152  			return
   153  		}
   154  		calculatedLength += 16 * int(lenOutputEach[i])
   155  	}
   156  
   157  	if calculatedLength != totalLength {
   158  		err = &DeserializeError{}
   159  		return
   160  	}
   161  
   162  	err = readToSlice(reader, lenBase, &m.base)
   163  	if err != nil {
   164  		return
   165  	}
   166  	err = readToSlice(reader, lenCheck, &m.check)
   167  	if err != nil {
   168  		return
   169  	}
   170  	err = readToSlice(reader, lenFail, &m.fail)
   171  	if err != nil {
   172  		return
   173  	}
   174  	m.output = make([][]SWord, lenOutput)
   175  	for i, v := range lenOutputEach {
   176  		err = readToSliceSWord(reader, v, &m.output[i])
   177  		if err != nil {
   178  			return
   179  		}
   180  	}
   181  
   182  	return
   183  }
   184  
   185  func readToSlice(reader *bytes.Reader, len uint64, array *[]int) error {
   186  	*array = make([]int, len)
   187  	var item uint64
   188  	for i := 0; i < int(len); i++ {
   189  		err := binary.Read(reader, binary.LittleEndian, &item)
   190  		if err != nil {
   191  			return err
   192  		}
   193  		(*array)[i] = int(item)
   194  	}
   195  	return nil
   196  }
   197  
   198  func readToSliceSWord(reader *bytes.Reader, len uint64, array *[]SWord) error {
   199  	*array = make([]SWord, len)
   200  	var item uint64
   201  	var err error
   202  	for i := 0; i < int(len); i++ {
   203  		err = binary.Read(reader, binary.LittleEndian, &item)
   204  		if err != nil {
   205  			return err
   206  		}
   207  		(*array)[i].Len = item
   208  		err = binary.Read(reader, binary.LittleEndian, &item)
   209  		if err != nil {
   210  			return err
   211  		}
   212  		(*array)[i].Key = item
   213  	}
   214  	return nil
   215  }
   216  
   217  func (m *Matcher) String() string {
   218  	return fmt.Sprintf(`
   219  Base:   %v
   220  Check:  %v
   221  Fail:   %v
   222  Output: %v
   223  `, m.base, m.check, m.fail, m.output)
   224  }
   225  
   226  type byteSliceSlice [][]byte
   227  
   228  func (bss byteSliceSlice) Len() int           { return len(bss) }
   229  func (bss byteSliceSlice) Less(i, j int) bool { return bytes.Compare(bss[i], bss[j]) < 1 }
   230  func (bss byteSliceSlice) Swap(i, j int)      { bss[i], bss[j] = bss[j], bss[i] }
   231  
   232  func compile(words [][]byte) *Matcher {
   233  	m := new(Matcher)
   234  	m.base = make([]int, 2048)[:1]
   235  	m.check = make([]int, 2048)[:1]
   236  	m.fail = make([]int, 2048)[:1]
   237  	m.output = make([][]SWord, 2048)[:1]
   238  
   239  	sort.Sort(byteSliceSlice(words))
   240  
   241  	// Represents a node in the implicit trie of words
   242  	type trienode struct {
   243  		state int
   244  		depth int
   245  		start int
   246  		end   int
   247  	}
   248  	queue := make([]trienode, 2048)[:1]
   249  	queue[0] = trienode{0, 0, 0, len(words)}
   250  
   251  	for len(queue) > 0 {
   252  		node := queue[0]
   253  		queue = queue[1:]
   254  
   255  		if node.end <= node.start {
   256  			m.base[node.state] = leaf
   257  			continue
   258  		}
   259  
   260  		var edges []byte
   261  		for i := node.start; i < node.end; i++ {
   262  			if len(edges) == 0 || edges[len(edges)-1] != words[i][node.depth] {
   263  				edges = append(edges, words[i][node.depth])
   264  			}
   265  		}
   266  
   267  		// Calculate a suitable Base value where each edge will fit into the
   268  		// double array trie
   269  		base := m.findBase(edges)
   270  		m.base[node.state] = base
   271  
   272  		i := node.start
   273  		for _, edge := range edges {
   274  			offset := int(edge)
   275  			newState := base + offset
   276  
   277  			m.occupyState(newState, node.state)
   278  
   279  			// level 0 and level 1 should fail to state 0
   280  			if node.depth > 0 {
   281  				m.setFailState(newState, node.state, offset)
   282  			}
   283  			m.unionFailOutput(newState, m.fail[newState])
   284  
   285  			// Add the child nodes to the queue to continue down the BFS
   286  			newnode := trienode{newState, node.depth + 1, i, i}
   287  			for {
   288  				if newnode.depth >= len(words[i]) {
   289  					m.output[newState] = append(m.output[newState], SWord{uint64(len(words[i])), uint64(i)})
   290  					newnode.start++
   291  				}
   292  				newnode.end++
   293  
   294  				i++
   295  				if i >= node.end || words[i][node.depth] != edge {
   296  					break
   297  				}
   298  			}
   299  			queue = append(queue, newnode)
   300  		}
   301  	}
   302  
   303  	return m
   304  }
   305  
   306  // CompileByteSlices compiles a Matcher from a slice of byte slices. This Matcher can be
   307  // used to find occurrences of each pattern in a text.
   308  func CompileByteSlices(words [][]byte) *Matcher {
   309  	return compile(words)
   310  }
   311  
   312  // CompileStrings compiles a Matcher from a slice of strings. This Matcher can
   313  // be used to find occurrences of each pattern in a text.
   314  func CompileStrings(words []string) *Matcher {
   315  	var wordByteSlices [][]byte
   316  	for _, word := range words {
   317  		wordByteSlices = append(wordByteSlices, []byte(word))
   318  	}
   319  	return compile(wordByteSlices)
   320  }
   321  
   322  // occupyState will correctly occupy state so it maintains the
   323  // index=check[base[index]+offset] identity. It will also update the
   324  // bidirectional link of free states correctly.
   325  // Note: This MUST be used instead of simply modifying the check array directly
   326  // which is break the bidirectional link of free states.
   327  func (m *Matcher) occupyState(state, parentState int) {
   328  	firstFreeState := m.firstFreeState()
   329  	lastFreeState := m.lastFreeState()
   330  	if firstFreeState == lastFreeState {
   331  		m.check[0] = 0
   332  	} else {
   333  		switch state {
   334  		case firstFreeState:
   335  			next := -1 * m.check[state]
   336  			m.check[0] = -1 * next
   337  			m.base[next] = m.base[state]
   338  		case lastFreeState:
   339  			prev := -1 * m.base[state]
   340  			m.base[firstFreeState] = -1 * prev
   341  			m.check[prev] = -1
   342  		default:
   343  			next := -1 * m.check[state]
   344  			prev := -1 * m.base[state]
   345  			m.check[prev] = -1 * next
   346  			m.base[next] = -1 * prev
   347  		}
   348  	}
   349  	m.check[state] = parentState
   350  	m.base[state] = leaf
   351  }
   352  
   353  // setFailState sets the output of the fail function for input state. It will
   354  // traverse up the fail states of it's ancestors until it reaches a fail state
   355  // with a transition for offset.
   356  func (m *Matcher) setFailState(state, parentState, offset int) {
   357  	failState := m.fail[parentState]
   358  	for {
   359  		if m.hasEdge(failState, offset) {
   360  			m.fail[state] = m.base[failState] + offset
   361  			break
   362  		}
   363  		if failState == 0 {
   364  			break
   365  		}
   366  		failState = m.fail[failState]
   367  	}
   368  }
   369  
   370  // unionFailOutput unions the output function for failState with the output
   371  // function for state and sets the result as the output function for state.
   372  // This allows us to match substrings, commenting out this body would match
   373  // every word that is not a substring.
   374  func (m *Matcher) unionFailOutput(state, failState int) {
   375  	m.output[state] = append([]SWord{}, m.output[failState]...)
   376  }
   377  
   378  // findBase finds a base value which has free states in the positions that
   379  // correspond to each edge transition in edges. If this does not exist, then
   380  // base and check (and the fail array for consistency) will be extended just
   381  // enough to fit each transition.
   382  // The extension will maintain the bidirectional link of free states.
   383  func (m *Matcher) findBase(edges []byte) int {
   384  	if len(edges) == 0 {
   385  		return leaf
   386  	}
   387  
   388  	min := int(edges[0])
   389  	max := int(edges[len(edges)-1])
   390  	width := max - min
   391  	freeState := m.firstFreeState()
   392  	for freeState != -1 {
   393  		valid := true
   394  		for _, e := range edges[1:] {
   395  			state := freeState + int(e) - min
   396  			if state >= len(m.check) {
   397  				break
   398  			} else if m.check[state] >= 0 {
   399  				valid = false
   400  				break
   401  			}
   402  		}
   403  
   404  		if valid {
   405  			if freeState+width >= len(m.check) {
   406  				m.increaseSize(width - len(m.check) + freeState + 1)
   407  			}
   408  			return freeState - min
   409  		}
   410  
   411  		freeState = m.nextFreeState(freeState)
   412  	}
   413  	freeState = len(m.check)
   414  	m.increaseSize(width + 1)
   415  	return freeState - min
   416  }
   417  
   418  // increaseSize increases the size of base, check, and fail to ensure they
   419  // remain the same size.
   420  // It also sets the default value for these new unoccupied states which form
   421  // bidirectional links to allow fast access to empty states. These
   422  // bidirectional links only pertain to base and check.
   423  //
   424  // Example:
   425  // m:
   426  //
   427  //	base:  [ 5 0 0 ]
   428  //	check: [ 0 0 0 ]
   429  //
   430  // increaseSize(3):
   431  //
   432  //	base:  [ 5  0 0 -5 -3 -4 ]
   433  //	check: [ -3 0 0 -4 -5 -1 ]
   434  //
   435  // increaseSize(3):
   436  //
   437  //	base:  [ 5  0 0 -8 -3 -4 -5 -6 -7]
   438  //	check: [ -3 0 0 -4 -5 -6 -7 -8 -1]
   439  //
   440  // m:
   441  //
   442  //	base:  [ 5 0 0 ]
   443  //	check: [ 0 0 0 ]
   444  //
   445  // increaseSize(1):
   446  //
   447  //	base:  [ 5  0 0 -3 ]
   448  //	check: [ -3 0 0 -1 ]
   449  //
   450  // increaseSize(1):
   451  //
   452  //	base:  [ 5  0 0 -4 -3 ]
   453  //	check: [ -3 0 0 -4 -1 ]
   454  //
   455  // increaseSize(1):
   456  //
   457  //	base:  [ 5  0 0 -5 -3 -4 ]
   458  //	check: [ -3 0 0 -4 -5 -1 ]
   459  func (m *Matcher) increaseSize(dsize int) {
   460  	if dsize == 0 {
   461  		return
   462  	}
   463  
   464  	m.base = append(m.base, make([]int, dsize)...)
   465  	m.check = append(m.check, make([]int, dsize)...)
   466  	m.fail = append(m.fail, make([]int, dsize)...)
   467  	m.output = append(m.output, make([][]SWord, dsize)...)
   468  
   469  	lastFreeState := m.lastFreeState()
   470  	firstFreeState := m.firstFreeState()
   471  	for i := len(m.check) - dsize; i < len(m.check); i++ {
   472  		if lastFreeState == -1 {
   473  			m.check[0] = -1 * i
   474  			m.base[i] = -1 * i
   475  			m.check[i] = -1
   476  			firstFreeState = i
   477  			lastFreeState = i
   478  		} else {
   479  			m.base[i] = -1 * lastFreeState
   480  			m.check[i] = -1
   481  			m.base[firstFreeState] = -1 * i
   482  			m.check[lastFreeState] = -1 * i
   483  			lastFreeState = i
   484  		}
   485  	}
   486  }
   487  
   488  // nextFreeState uses the nature of the bidirectional link to determine the
   489  // closest free state at a larger index. Since the check array holds the
   490  // negative index of the next free state, except for the last free state which
   491  // has a value of -1, negating this value is the next free state.
   492  func (m *Matcher) nextFreeState(curFreeState int) int {
   493  	nextState := -1 * m.check[curFreeState]
   494  
   495  	// state 1 can never be a free state.
   496  	if nextState == 1 {
   497  		return -1
   498  	}
   499  
   500  	return nextState
   501  }
   502  
   503  // firstFreeState uses the first value in the check array which points to the
   504  // first free state. A value of 0 means there are no free states and -1 is
   505  // returned.
   506  func (m *Matcher) firstFreeState() int {
   507  	state := m.check[0]
   508  	if state != 0 {
   509  		return -1 * state
   510  	}
   511  	return -1
   512  }
   513  
   514  // lastFreeState uses the base value of the first free state which points the
   515  // last free state.
   516  func (m *Matcher) lastFreeState() int {
   517  	firstFree := m.firstFreeState()
   518  	if firstFree != -1 {
   519  		return -1 * m.base[firstFree]
   520  	}
   521  	return -1
   522  }
   523  
   524  // hasEdge determines if the fromState has a transition for offset.
   525  func (m *Matcher) hasEdge(fromState, offset int) bool {
   526  	toState := m.base[fromState] + offset
   527  	return toState > 0 && toState < len(m.check) && m.check[toState] == fromState
   528  }
   529  
   530  // Match represents a matched pattern in the text
   531  type Match struct {
   532  	Word  []byte // the matched pattern
   533  	Index int    // the start index of the match
   534  }
   535  
   536  type Matches interface {
   537  	Append(key int, position int)
   538  	Count() int
   539  }
   540  
   541  func (m *Matcher) findAll(text []byte) []*Match {
   542  	var matches []*Match
   543  	state := 0
   544  	for i, b := range text {
   545  		offset := int(b)
   546  		for state != 0 && !m.hasEdge(state, offset) {
   547  			state = m.fail[state]
   548  		}
   549  
   550  		if m.hasEdge(state, offset) {
   551  			state = m.base[state] + offset
   552  		}
   553  		for _, item := range m.output[state] {
   554  			matches = append(matches, &Match{text[i-int(item.Len)+1 : i+1], i - int(item.Len) + 1})
   555  		}
   556  	}
   557  	return matches
   558  }
   559  
   560  func (m *Matcher) findAllReader(reader io.Reader, matches Matches) {
   561  	state := 0
   562  	buf := make([]byte, 1)
   563  	n, err := reader.Read(buf)
   564  	b := int(buf[0])
   565  	i := 1
   566  	for err == nil && n == 1 {
   567  		offset := b
   568  		for state != 0 && !m.hasEdge(state, offset) {
   569  			state = m.fail[state]
   570  		}
   571  
   572  		if m.hasEdge(state, offset) {
   573  			state = m.base[state] + offset
   574  		}
   575  		for _, item := range m.output[state] {
   576  			matches.Append(i, int(item.Key))
   577  		}
   578  		n, err = reader.Read(buf)
   579  		b = int(buf[0])
   580  		i++
   581  	}
   582  }
   583  
   584  // FindAllByteSlice finds all instances of the patterns in the text.
   585  func (m *Matcher) FindAllByteSlice(text []byte) (matches []*Match) {
   586  	return m.findAll(text)
   587  }
   588  
   589  func (m *Matcher) FindAllByteReader(reader io.Reader, matches Matches) {
   590  	m.findAllReader(reader, matches)
   591  }
   592  
   593  // FindAllString finds all instances of the patterns in the text.
   594  func (m *Matcher) FindAllString(text string) []*Match {
   595  	return m.FindAllByteSlice([]byte(text))
   596  }