github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/wordsearch/trie-compact/compact.go (about)

     1  package trie
     2  
     3  import (
     4  	"unsafe"
     5  )
     6  
     7  const (
     8  	termbit   = 1
     9  	maxsuffix = 2
    10  )
    11  
    12  type cindex uint32
    13  
    14  type Compact struct {
    15  	nodes []cnode
    16  	root  cnode
    17  }
    18  
    19  type cnode struct {
    20  	label   byte
    21  	flags   byte // (edge count << 1) | term
    22  	suffix  [maxsuffix]byte
    23  	edgeptr cindex
    24  }
    25  
    26  func (ut *Uncompact) Compress() *Compact {
    27  	t := &Compact{}
    28  	t.nodes = make([]cnode, 0, ut.root.count())
    29  	t.compress(&t.root, &ut.root)
    30  	return t
    31  }
    32  
    33  func (t *Compact) compress(dst *cnode, src *unode) {
    34  	dst.label = src.label
    35  	dst.flags = byte(len(src.edges) << 1)
    36  	if src.term {
    37  		dst.flags |= termbit
    38  	}
    39  	copy(dst.suffix[:], []byte(src.suffix))
    40  
    41  	dst.edgeptr = cindex(len(t.nodes))
    42  	t.nodes = t.nodes[: len(t.nodes)+len(src.edges) : cap(t.nodes)]
    43  	for i := range src.edges {
    44  		t.compress(t.edge(dst, i), &src.edges[i])
    45  	}
    46  }
    47  
    48  func (t *Compact) ContainsBytes(s []byte) bool {
    49  	node := &t.root
    50  next:
    51  	for i, b := range s {
    52  		n := node.edgecount()
    53  		if node.suffix[0] != 0 && len(s)-i <= maxsuffix {
    54  			if node.suffixMatchBytes(s[i:]) {
    55  				return true
    56  			}
    57  		}
    58  		for i := 0; i < n; i++ {
    59  			child := t.edge(node, i)
    60  			if child.label == b {
    61  				node = child
    62  				continue next
    63  			}
    64  		}
    65  		return false
    66  	}
    67  	return node.terminates()
    68  }
    69  func (n *cnode) suffixMatchBytes(s []byte) bool {
    70  	for i, b := range n.suffix {
    71  		if b == 0 {
    72  			return i == len(s)
    73  		}
    74  		if i >= len(s) {
    75  			return false
    76  		}
    77  		if b != s[i] {
    78  			return false
    79  		}
    80  	}
    81  	return true
    82  }
    83  
    84  func (t *Compact) Contains(s string) bool {
    85  	node := &t.root
    86  next:
    87  	for i := 0; i < len(s); i++ {
    88  		b := s[i]
    89  		n := node.edgecount()
    90  		if node.suffix[0] != 0 && len(s)-i <= maxsuffix {
    91  			if node.suffixMatch(s[i:]) {
    92  				return true
    93  			}
    94  		}
    95  		for i := 0; i < n; i++ {
    96  			child := t.edge(node, i)
    97  			if child.label == b {
    98  				node = child
    99  				continue next
   100  			}
   101  		}
   102  		return false
   103  	}
   104  	return node.terminates()
   105  }
   106  func (n *cnode) suffixMatch(s string) bool {
   107  	for i, b := range n.suffix {
   108  		if b == 0 {
   109  			return i == len(s)
   110  		}
   111  		if i >= len(s) {
   112  			return false
   113  		}
   114  		if b != s[i] {
   115  			return false
   116  		}
   117  	}
   118  	return true
   119  }
   120  func (t *Compact) edge(n *cnode, i int) *cnode { return &t.nodes[int(n.edgeptr)+i] }
   121  
   122  func (n *cnode) edgecount() int   { return int(n.flags >> 1) }
   123  func (n *cnode) terminates() bool { return n.flags&termbit == termbit }
   124  
   125  // debugging
   126  
   127  func (t *Compact) Size() int { return int(unsafe.Sizeof(cnode{})) * len(t.nodes) }
   128  
   129  func (t *Compact) NodeCount() int { return len(t.nodes) }
   130  
   131  func (t *Compact) MaxOffset() int {
   132  	max := 0
   133  	for i := range t.nodes {
   134  		dist := int(t.nodes[i].edgeptr) - i
   135  		if dist > max {
   136  			max = dist
   137  		}
   138  	}
   139  	return max
   140  }
   141  
   142  func (t *Compact) MaxEdges() int {
   143  	max := 0
   144  	for i := range t.nodes {
   145  		edgelen := int(t.nodes[i].flags >> 1)
   146  		if edgelen > max {
   147  			max = edgelen
   148  		}
   149  	}
   150  	return max
   151  }
   152  
   153  func safesuffix(s string) bool {
   154  	if len(s) > maxsuffix {
   155  		return false
   156  	}
   157  	for i := 0; i < len(s); i++ {
   158  		if s[i] == 0 {
   159  			return false
   160  		}
   161  	}
   162  	return true
   163  }