github.com/cilium/cilium@v1.16.2/pkg/container/bitlpm/trie.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package bitlpm
     5  
     6  // Trie is a [non-preemptive] [binary] [trie] that indexes arbitrarily long
     7  // bit-based keys with associated prefix lengths indexed from [most significant bit]
     8  // ("MSB") to [least significant bit] ("LSB") using the
     9  // [longest prefix match algorithm].
    10  //
    11  // A prefix-length (hereafter "prefix"), in a prefix-key pair, represents the
    12  // minimum number of bits (from MSB to LSB) that another comparable key
    13  // must match.
    14  //
    15  // Each method's comments describes the mechanism of how the method
    16  // works.
    17  //
    18  // [non-preemptive]: https://en.wikipedia.org/wiki/Preemption_(computing)
    19  // [binary]: https://en.wikipedia.org/wiki/Binary_number
    20  // [trie]: https://en.wikipedia.org/wiki/Trie
    21  // [most significant bit]: https://en.wikipedia.org/wiki/Bit_numbering#Most_significant_bit
    22  // [least significant bit]: https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit
    23  // [longest prefix match algorithm]: https://en.wikipedia.org/wiki/Longest_prefix_match
    24  type Trie[K, T any] interface {
    25  	// ExactLookup returns a value only if the prefix and key
    26  	// match an entry in the Trie exactly.
    27  	//
    28  	// Note: If the prefix argument exceeds the Trie's maximum
    29  	// prefix, it will be set to the Trie's maximum prefix.
    30  	ExactLookup(prefix uint, key K) (v T, ok bool)
    31  	// LongestPrefixMatch returns the longest prefix match for a specific
    32  	// key.
    33  	LongestPrefixMatch(key K) (v T, ok bool)
    34  	// Ancestors iterates over every prefix-key pair that contains
    35  	// the prefix-key argument pair. If the Ancestors function argument
    36  	// returns false the iteration will stop. Ancestors will iterate
    37  	// keys from shortest to longest prefix match (that is, the
    38  	// longest match will be returned last).
    39  	//
    40  	// Note: If the prefix argument exceeds the Trie's maximum
    41  	// prefix, it will be set to the Trie's maximum prefix.
    42  	Ancestors(prefix uint, key K, fn func(uint, K, T) bool)
    43  	// Descendants iterates over every prefix-key pair that is contained
    44  	// by the prefix-key argument pair. If the Descendants function argument
    45  	// returns false the iteration will stop. Descendants does **not** iterate
    46  	// over matches in any guaranteed order.
    47  	//
    48  	// Note: If the prefix argument exceeds the Trie's maximum
    49  	// prefix, it will be set to the Trie's maximum prefix.
    50  	Descendants(prefix uint, key K, fn func(uint, K, T) bool)
    51  	// Upsert updates or inserts the trie with a a prefix, key,
    52  	// and value.
    53  	//
    54  	// Note: If the prefix argument exceeds the Trie's maximum
    55  	// prefix, it will be set to the Trie's maximum prefix.
    56  	Upsert(prefix uint, key K, value T)
    57  	// Delete removes a key with the exact given prefix and returns
    58  	// false if the key was not found.
    59  	//
    60  	// Note: If the prefix argument exceeds the Trie's maximum
    61  	// prefix, it will be set to the Trie's maximum prefix.
    62  	Delete(prefix uint, key K) bool
    63  	// Len returns the number of entries in the Trie
    64  	Len() uint
    65  	// ForEach iterates over every element of the Trie in no particular
    66  	// order. If the function argument returns false the iteration stops.
    67  	ForEach(fn func(uint, K, T) bool)
    68  }
    69  
    70  // Key is an interface that implements all the necessary
    71  // methods to index and retrieve keys.
    72  type Key[K any] interface {
    73  	// CommonPrefix returns the number of bits that
    74  	// are the same between this key and the argument
    75  	// value, starting from MSB.
    76  	CommonPrefix(K) uint
    77  	// BitValueAt returns the value of the bit at an argument
    78  	// index. MSB is 0 and LSB is n-1.
    79  	BitValueAt(uint) uint8
    80  	// Value returns the underlying value of the Key.
    81  	Value() K
    82  }
    83  
// trie is the generic implementation of a bit-trie that can
// accept arbitrary keys conforming to the Key[K] interface.
//
// Fields:
//   - root: the top node of the trie, or nil when no nodes are present.
//   - maxPrefix: the largest prefix length the trie stores. The public
//     methods clamp larger prefix arguments to this value with min().
//   - entries: the number of stored prefix-key pairs, as returned by Len.
//     Intermediate nodes are not counted.
type trie[K, T any] struct {
	root      *node[K, T]
	maxPrefix uint
	entries   uint
}
    91  
    92  // NewTrie returns a Trie that accepts the Key[K any] interface
    93  // as its key argument. This enables the user of this Trie to
    94  // define their own bit-key.
    95  func NewTrie[K, T any](maxPrefix uint) Trie[Key[K], T] {
    96  	return &trie[K, T]{
    97  		maxPrefix: maxPrefix,
    98  	}
    99  }
   100  
// node represents a specific key and prefix in the trie.
//
// Fields:
//   - children: the two subtrees below this node. During a lookup the slot
//     is chosen by the search key's bit at index prefixLen.
//   - prefixLen: the number of leading key bits (from the MSB) this node
//     covers.
//   - key: the node's key. For an intermediate node created by Upsert this
//     is the key of the existing node it was split from.
//   - intermediate: true for structural nodes that hold no entry. traverse
//     and forEach skip them, and they are not counted in trie.entries.
//   - value: the stored value. Delete resets it to T's zero value when it
//     turns a node into an intermediate node.
type node[K, T any] struct {
	children     [2]*node[K, T]
	prefixLen    uint
	key          Key[K]
	intermediate bool
	value        T
}
   109  
   110  // ExactLookup returns a value only if the prefix and key
   111  // match an entry in the Trie exactly.
   112  //
   113  // Note: If the prefix argument exceeds the Trie's maximum
   114  // prefix, it will be set to the Trie's maximum prefix.
   115  func (t *trie[K, T]) ExactLookup(prefixLen uint, k Key[K]) (ret T, found bool) {
   116  	prefixLen = min(prefixLen, t.maxPrefix)
   117  	t.traverse(prefixLen, k, func(currentNode *node[K, T], matchLen uint) bool {
   118  		// Only copy node value if exact prefix length is found
   119  		if matchLen == prefixLen {
   120  			ret = currentNode.value
   121  			found = true
   122  			return false // no need to continue
   123  		}
   124  		return true
   125  	})
   126  	return ret, found
   127  }
   128  
   129  // LongestPrefixMatch returns the value for the key with the
   130  // longest prefix match of the argument key.
   131  func (t *trie[K, T]) LongestPrefixMatch(k Key[K]) (T, bool) {
   132  	// default return value
   133  	var (
   134  		empty T
   135  		ok    bool
   136  	)
   137  	ret := &empty
   138  	t.traverse(t.maxPrefix, k, func(currentNode *node[K, T], matchLen uint) bool {
   139  		ret = &currentNode.value
   140  		ok = true
   141  		return true
   142  	})
   143  	return *ret, ok
   144  }
   145  
   146  // Ancestors calls the function argument for every prefix/key/value in the trie
   147  // that contains the prefix-key argument pair in order from shortest to longest
   148  // prefix match. If the function argument returns false the iteration stops.
   149  //
   150  // Note: Ancestors sets any prefixLen argument that exceeds the maximum
   151  // prefix allowed by the trie to the maximum prefix allowed by the
   152  // trie.
   153  func (t *trie[K, T]) Ancestors(prefixLen uint, k Key[K], fn func(prefix uint, key Key[K], value T) bool) {
   154  	prefixLen = min(prefixLen, t.maxPrefix)
   155  	t.traverse(prefixLen, k, func(currentNode *node[K, T], matchLen uint) bool {
   156  		return fn(currentNode.prefixLen, currentNode.key, currentNode.value)
   157  	})
   158  }
   159  
   160  // Descendants calls the function argument for every prefix/key/value in the
   161  // trie that is contained by the prefix-key argument pair. If the function
   162  // argument returns false the iteration stops. Descendants does **not** iterate
   163  // over matches in any guaranteed order.
   164  //
   165  // Note: Descendants sets any prefixLen argument that exceeds the maximum
   166  // prefix allowed by the trie to the maximum prefix allowed by the
   167  // trie.
   168  func (t *trie[K, T]) Descendants(prefixLen uint, k Key[K], fn func(prefix uint, key Key[K], value T) bool) {
   169  	if k == nil {
   170  		return
   171  	}
   172  	prefixLen = min(prefixLen, t.maxPrefix)
   173  	currentNode := t.root
   174  	for currentNode != nil {
   175  		matchLen := currentNode.prefixMatch(prefixLen, k)
   176  		// CurrentNode matches the prefix-key argument
   177  		if matchLen >= prefixLen {
   178  			currentNode.forEach(fn)
   179  			return
   180  		}
   181  		// currentNode is a leaf and has no children. Calling k.BitValueAt may
   182  		// overrun the key storage.
   183  		if currentNode.prefixLen >= t.maxPrefix {
   184  			return
   185  		}
   186  		currentNode = currentNode.children[k.BitValueAt(currentNode.prefixLen)]
   187  	}
   188  }
   189  
   190  // traverse iterates over every prefix-key pair that contains the
   191  // prefix-key argument pair in order from shortest to longest prefix
   192  // match. If the function argument returns false the iteration will stop.
   193  //
   194  // traverse starts at the root node in the trie.
   195  // The key and prefix being searched (the "search" key and prefix) are
   196  // compared to the a trie node's key and prefix (the "node" key and
   197  // prefix) to determine the extent to which the keys match (from MSB to
   198  // LSB) up to the **least** specific (or shortest) prefix of the two keys
   199  // (for example, if one of the keys has a prefix length of 2 and the other has
   200  // a prefix length of 3 then the two keys will be compared up to the 2nd bit).
   201  // If the key's match less than the node prefix (that is, the search
   202  // key did not fully match the node key) then the traversal ends.
   203  // If the key's match was greater than or equal to the node prefix
   204  // then the node key is iterated over as a potential match,
   205  // but traversal continues to ensure that there is not a more specific
   206  // (that is, longer) match. The next bit, after the match length (between
   207  // the search key and node key), on the search key is looked up to
   208  // determine which children of the current node to traverse (to
   209  // check if there is a more specific match). If there is no child then
   210  // traversal ends. Otherwise traversal continues.
   211  func (t *trie[K, T]) traverse(prefixLen uint, k Key[K], fn func(currentNode *node[K, T], matchLen uint) bool) {
   212  	if k == nil {
   213  		return
   214  	}
   215  	for currentNode := t.root; currentNode != nil; currentNode = currentNode.children[k.BitValueAt(currentNode.prefixLen)] {
   216  		matchLen := currentNode.prefixMatch(prefixLen, k)
   217  		// The current-node does not match.
   218  		if matchLen < currentNode.prefixLen {
   219  			return
   220  		}
   221  		// Skip over intermediate nodes
   222  		if currentNode.intermediate {
   223  			continue
   224  		}
   225  		if !fn(currentNode, matchLen) || matchLen == t.maxPrefix {
   226  			return
   227  		}
   228  	}
   229  }
   230  
   231  // Upsert inserts or replaces a key and prefix (an "upsert" key and
   232  // prefix) below keys that match it with a smaller (that is, less
   233  // specific) prefix and above keys that match it with a
   234  // more specific (that is "higher") prefix.
   235  //
   236  // Upsert starts with the root key (or "node"). The upsert key and node
   237  // key are compared for the match length between them (see the
   238  // `traverse` comments for details on how this works). If the match
   239  // length is exactly equal to the node prefix then traversal
   240  // continues as the next bit after the match length in the upsert key
   241  // corresponds to one of the two child slots that belong to the node
   242  // key. If the match length is not exactly equal, or there is no child
   243  // to traverse to, or the node prefix is exactly equal to the
   244  // upsert prefix (these conditions are not mutually exclusive) then traversal
   245  // is finished. There are four possible insertion/replacement condtions
   246  // to consider:
   247  //  1. The node key is nil (that is, an empty children "slot"), in which
   248  //     case the previous key iterated over should be the upsert-key's
   249  //     parent. If there is no parent then the node key is now the
   250  //     root node.
   251  //  2. The node key matches the upsert-node to the exact
   252  //     prefix. Then the upsert key should replace the node key.
   253  //  3. The node key matches the upsert key to the upsert prefix,
   254  //     but node prefix is greater than the upsert prefix. In this
   255  //     case the node key will become a child of the upsert key.
   256  //  4. The node key does not match with the upsert key to either
   257  //     the node prefix or the upsert prefix. In this case an
   258  //     intermediate node needs to be inserted that replaces the
   259  //     current position of the node key, but give it a prefix
   260  //     of the match between the upsert key and node key. The
   261  //     node key and upsert key become siblings.
   262  //
   263  // Intermediate keys/nodes:
   264  // Sometimes when a new key is inserted it does not match any key up to
   265  // its own prefix or its closest matching key's prefix. When this
   266  // happens an intermediate node with the common prefix of the upsert
   267  // key and closest match key. The new intermediate key replaces the closest
   268  // match key's position in the trie and takes the closest match key and
   269  // upsert key as children.
   270  //
   271  // For example, assuming a key size of 8 bytes, adding the prefix-keys of
   272  // "0b001/8"(1-1), "0b010/7"(2-3), and "0b100/6"(4-7) would follow this logic:
   273  //
   274  //  1. "0b001/8" gets added first. It becomes the root node.
   275  //  2. "0b010/7" is added. It will match "0b001/8" (the root node) up to
   276  //     6 bits, because "0b010/7"'s 7th bit is 1 and "0b001/8" has 7th bit of 0.
   277  //     In this case, an intermediate node "0b001/6" will be created (the extent
   278  //     to which "0b010/7" and "0b001/8" match). The new intermediate node will
   279  //     have children "0b001/8" (in the 0 slot) and "0b010/7" (in the 1 slot).
   280  //     This new intermediate node become the new root node.
   281  //  3. When "0b100/6" is added it will match the new root (which happens to
   282  //     be an intermediate node) "0b001/6" up to 5 bits. Therefore another
   283  //     intermediate node of "0b001/5" will be created, becoming the new root
   284  //     node. "0b001/6" will become the new intermediate node's child in the
   285  //     0 slot and "0b100/6" will become the other child in the 1 slot.
   286  //     "0b001/5" becomes the new root node.
   287  //
   288  // Note: Upsert sets any "prefixLen" argument that exceeds the maximum
   289  // prefix allowed by the trie to the maximum prefix allowed by the
   290  // trie.
   291  func (t *trie[K, T]) Upsert(prefixLen uint, k Key[K], value T) {
   292  	if k == nil {
   293  		return
   294  	}
   295  	prefixLen = min(prefixLen, t.maxPrefix)
   296  	upsertNode := &node[K, T]{
   297  		prefixLen: prefixLen,
   298  		key:       k,
   299  		value:     value,
   300  	}
   301  
   302  	var (
   303  		matchLen uint
   304  		parent   *node[K, T]
   305  		bitVal   uint8
   306  	)
   307  
   308  	currentNode := t.root
   309  	for currentNode != nil {
   310  		matchLen = currentNode.prefixMatch(prefixLen, k)
   311  		// The current node does not match the upsert-{prefix,key}
   312  		// or the current node matches to the maximum extent
   313  		// allowable by either the trie or the upsert-prefix.
   314  		if currentNode.prefixLen != matchLen ||
   315  			currentNode.prefixLen == t.maxPrefix ||
   316  			currentNode.prefixLen == prefixLen {
   317  			break
   318  		}
   319  		bitVal = k.BitValueAt(currentNode.prefixLen)
   320  		parent = currentNode
   321  		currentNode = currentNode.children[bitVal]
   322  	}
   323  	t.entries++
   324  	// Empty slot.
   325  	if currentNode == nil {
   326  		if parent == nil {
   327  			t.root = upsertNode
   328  		} else {
   329  			parent.children[bitVal] = upsertNode
   330  		}
   331  		return
   332  	}
   333  	// There are three cases:
   334  	// 1. The current-node matches the upsert-node to the exact
   335  	//    prefix. Then the upsert-node should replace the current-node.
   336  	// 2. The current-node matches the upsert-node, but the
   337  	//    current-node has a more specific prefix than the
   338  	//    upsert-node. Then the current-node should become a child
   339  	//    of the upsert-node.
   340  	// 3. The current-node does not match with the upsert-node,
   341  	//    but they overlap. Then a new intermediate-node should replace
   342  	//    the current-node with a prefix equal to the overlap.
   343  	//    The current-node and the upsert-node become children
   344  	//    of the new intermediate node.
   345  	//
   346  	//    For example, given two keys, "current" and "upsert":
   347  	//        current: 0b1010/4
   348  	//        upsert:  0b1000/3
   349  	//    A new key of "0b1010/2" would then be added as an intermediate key
   350  	//    (note: the 3rd bit does not matter, but unsetting is an extra
   351  	//    operation that we avoid). "current" would be a child of
   352  	//    intermediate at index "1" and "upsert" would be at index "0".
   353  
   354  	// The upsert-node matches the current-node up to the
   355  	// current-node's prefix, replace the current-node.
   356  	if matchLen == currentNode.prefixLen {
   357  		if parent == nil {
   358  			t.root = upsertNode
   359  		} else {
   360  			parent.children[bitVal] = upsertNode
   361  		}
   362  		// If we're not replacing an intermediate node
   363  		// then decrement this function's previous
   364  		// increment of `entries`.
   365  		if !currentNode.intermediate {
   366  			t.entries--
   367  		}
   368  		upsertNode.children[0] = currentNode.children[0]
   369  		upsertNode.children[1] = currentNode.children[1]
   370  		return
   371  	}
   372  
   373  	// The upsert-node matches the current-node up to
   374  	// the upsert-node's prefix, make the current-node
   375  	// a child of the upsert-node.
   376  	if matchLen == prefixLen {
   377  		if parent == nil {
   378  			t.root = upsertNode
   379  		} else {
   380  			parent.children[bitVal] = upsertNode
   381  		}
   382  		bitVal = currentNode.key.BitValueAt(matchLen)
   383  		upsertNode.children[bitVal] = currentNode
   384  		return
   385  	}
   386  	// The upsert-node does not match the current-node
   387  	// up to the upsert-node's prefix and the current-node
   388  	// does not match the upsert-node up to the
   389  	// current-node's prefix, make the nodes siblings with
   390  	// an intermediate node.
   391  	intermediateNode := &node[K, T]{
   392  		prefixLen:    matchLen,
   393  		key:          currentNode.key,
   394  		intermediate: true,
   395  	}
   396  	if parent == nil {
   397  		t.root = intermediateNode
   398  	} else {
   399  		parent.children[bitVal] = intermediateNode
   400  	}
   401  	if k.BitValueAt(matchLen) == 0 {
   402  		intermediateNode.children[0] = upsertNode
   403  		intermediateNode.children[1] = currentNode
   404  	} else {
   405  		intermediateNode.children[0] = currentNode
   406  		intermediateNode.children[1] = upsertNode
   407  	}
   408  }
   409  
   410  // Delete deletes only keys that match the exact values of the
   411  // prefix length and key arguments.
   412  //
   413  // Delete traverses the trie until it either finds a node key
   414  // that does not match the delete key to the node key's prefix
   415  // (a definitive non-match) or the node key's prefix is equal
   416  // to the delete prefix (a potential deletion). If the delete prefix,
   417  // node prefix, and match length between the keys are equal to
   418  // the same value then the key is deleted from the trie.
   419  //
   420  // Note: Delete sets any prefixLen argument that exceeds the maximum
   421  // prefix allowed by the trie to the maximum prefix allowed by the
   422  // trie.
   423  func (t *trie[K, T]) Delete(prefixLen uint, k Key[K]) bool {
   424  	if k == nil {
   425  		return false
   426  	}
   427  	prefixLen = min(prefixLen, t.maxPrefix)
   428  
   429  	var (
   430  		grandParent, parent *node[K, T]
   431  		matchLen            uint
   432  		bitVal, prevBitVal  uint8
   433  	)
   434  
   435  	currentNode := t.root
   436  	for currentNode != nil {
   437  		// Find to what extent the current node matches with the
   438  		// delete-{prefix,key}.
   439  		matchLen = currentNode.prefixMatch(prefixLen, k)
   440  		// The current-node does not match or it has the same
   441  		// prefix length (the only potential deletion in the
   442  		// trie).
   443  		if currentNode.prefixLen != matchLen ||
   444  			currentNode.prefixLen == prefixLen {
   445  			break
   446  		}
   447  		prevBitVal = bitVal
   448  		bitVal = k.BitValueAt(currentNode.prefixLen)
   449  		// We preserve the grandParent in order
   450  		// to prune intermediate nodes when they
   451  		// are no longer necessary.
   452  		grandParent = parent
   453  		parent = currentNode
   454  		currentNode = currentNode.children[bitVal]
   455  	}
   456  	// Not found, or the current-node does not match
   457  	// the delete-prefix exactly, or the current-node
   458  	// does not match the delete-{prefix,key} lookup,
   459  	// or the current-node is intermediate.
   460  	if currentNode == nil ||
   461  		currentNode.prefixLen != prefixLen ||
   462  		currentNode.prefixLen != matchLen ||
   463  		currentNode.intermediate {
   464  		return false
   465  	}
   466  	t.entries--
   467  
   468  	// If this node has two children, we need to keep it as an intermediate
   469  	// node because we cannot migrate both children up the trie.
   470  	if currentNode.children[0] != nil && currentNode.children[1] != nil {
   471  		var emptyT T
   472  		currentNode.intermediate = true
   473  		// Make sure that the value associated with this intermediate
   474  		// node can be GC'd.
   475  		currentNode.value = emptyT
   476  		return true
   477  	}
   478  
   479  	// If the parent of the current-node to be deleted is an
   480  	// intermediate-node and the current-node has no children
   481  	// then the parent (intermediate) node can be deleted and
   482  	// its other child promoted up the trie.
   483  	if parent != nil && parent.intermediate &&
   484  		currentNode.children[0] == nil && currentNode.children[1] == nil {
   485  		var saveNode *node[K, T]
   486  		if k.BitValueAt(parent.prefixLen) == 0 {
   487  			saveNode = parent.children[1]
   488  		} else {
   489  			saveNode = parent.children[0]
   490  		}
   491  		parent.children[0] = nil
   492  		parent.children[1] = nil
   493  		if grandParent == nil {
   494  			t.root = saveNode
   495  		} else {
   496  			grandParent.children[prevBitVal] = saveNode
   497  		}
   498  		return true
   499  	}
   500  
   501  	// migrate the last child (if any) up the trie.
   502  	if currentNode.children[0] != nil {
   503  		currentNode = currentNode.children[0]
   504  	} else if currentNode.children[1] != nil {
   505  		currentNode = currentNode.children[1]
   506  	} else {
   507  		currentNode = nil
   508  	}
   509  	if parent == nil {
   510  		t.root = currentNode
   511  	} else {
   512  		parent.children[bitVal] = currentNode
   513  	}
   514  	return true
   515  }
   516  
   517  func (t *trie[K, T]) Len() uint {
   518  	return t.entries
   519  }
   520  
   521  func (t *trie[K, T]) ForEach(fn func(prefix uint, key Key[K], value T) bool) {
   522  	if t.root != nil {
   523  		t.root.forEach(fn)
   524  	}
   525  }
   526  
   527  // prefixMatch returns the length that the node key and
   528  // the argument key match, with the limit of the match being
   529  // the lesser of the node-key prefix or the argument-key prefix.
   530  func (n *node[K, T]) prefixMatch(prefix uint, k Key[K]) uint {
   531  	limit := min(n.prefixLen, prefix)
   532  	prefixLen := n.key.CommonPrefix(k.Value())
   533  	if prefixLen >= limit {
   534  		return limit
   535  	}
   536  	return prefixLen
   537  }
   538  
   539  // forEach calls the argument function for each key and value in
   540  // the subtree rooted at the current node
   541  func (n *node[K, T]) forEach(fn func(prefix uint, key Key[K], value T) bool) {
   542  	if !n.intermediate {
   543  		if !fn(n.prefixLen, n.key, n.value) {
   544  			return
   545  		}
   546  	}
   547  	if n.children[0] != nil {
   548  		n.children[0].forEach(fn)
   549  	}
   550  	if n.children[1] != nil {
   551  		n.children[1].forEach(fn)
   552  	}
   553  }