github.com/cilium/cilium@v1.16.2/pkg/container/bitlpm/cidr.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package bitlpm
     5  
     6  import (
     7  	"math/bits"
     8  	"net/netip"
     9  	"unsafe"
    10  )
    11  
    12  // CIDRTrie can hold both IPv4 and IPv6 prefixes
    13  // at the same time.
    14  type CIDRTrie[T any] struct {
    15  	v4 Trie[Key[netip.Prefix], T]
    16  	v6 Trie[Key[netip.Prefix], T]
    17  }
    18  
    19  // NewCIDRTrie creates a new CIDRTrie[T any].
    20  func NewCIDRTrie[T any]() *CIDRTrie[T] {
    21  	return &CIDRTrie[T]{
    22  		v4: NewTrie[netip.Prefix, T](32),
    23  		v6: NewTrie[netip.Prefix, T](128),
    24  	}
    25  }
    26  
    27  // ExactLookup returns the value for a given CIDR, but only
    28  // if there is an exact match for the CIDR in the Trie.
    29  func (c *CIDRTrie[T]) ExactLookup(cidr netip.Prefix) (T, bool) {
    30  	return c.treeForFamily(cidr).ExactLookup(uint(cidr.Bits()), cidrKey(cidr))
    31  }
    32  
    33  // LongestPrefixMatch returns the longest matched value for a given address.
    34  func (c *CIDRTrie[T]) LongestPrefixMatch(addr netip.Addr) (T, bool) {
    35  	if !addr.IsValid() {
    36  		var def T
    37  		return def, false
    38  	}
    39  	bits := addr.BitLen()
    40  	prefix := netip.PrefixFrom(addr, bits)
    41  	return c.treeForFamily(prefix).LongestPrefixMatch(cidrKey(prefix))
    42  }
    43  
    44  // Ancestors iterates over every CIDR pair that contains the CIDR argument.
    45  func (c *CIDRTrie[T]) Ancestors(cidr netip.Prefix, fn func(k netip.Prefix, v T) bool) {
    46  	c.treeForFamily(cidr).Ancestors(uint(cidr.Bits()), cidrKey(cidr), func(prefix uint, k Key[netip.Prefix], v T) bool {
    47  		return fn(k.Value(), v)
    48  	})
    49  }
    50  
    51  // Descendants iterates over every CIDR that is contained by the CIDR argument.
    52  func (c *CIDRTrie[T]) Descendants(cidr netip.Prefix, fn func(k netip.Prefix, v T) bool) {
    53  	c.treeForFamily(cidr).Descendants(uint(cidr.Bits()), cidrKey(cidr), func(prefix uint, k Key[netip.Prefix], v T) bool {
    54  		return fn(k.Value(), v)
    55  	})
    56  }
    57  
    58  // Upsert adds or updates the value for a given prefix.
    59  func (c *CIDRTrie[T]) Upsert(cidr netip.Prefix, v T) {
    60  	c.treeForFamily(cidr).Upsert(uint(cidr.Bits()), cidrKey(cidr), v)
    61  }
    62  
    63  // Delete removes a given prefix from the tree.
    64  func (c *CIDRTrie[T]) Delete(cidr netip.Prefix) bool {
    65  	return c.treeForFamily(cidr).Delete(uint(cidr.Bits()), cidrKey(cidr))
    66  }
    67  
    68  // Len returns the total number of ipv4 and ipv6 prefixes in the trie.
    69  func (c *CIDRTrie[T]) Len() uint {
    70  	return c.v4.Len() + c.v6.Len()
    71  }
    72  
    73  // ForEach iterates over every element of the Trie. It iterates over IPv4
    74  // keys first.
    75  func (c *CIDRTrie[T]) ForEach(fn func(k netip.Prefix, v T) bool) {
    76  	var v4Break bool
    77  	c.v4.ForEach(func(prefix uint, k Key[netip.Prefix], v T) bool {
    78  		if !fn(k.Value(), v) {
    79  			v4Break = true
    80  			return false
    81  		}
    82  		return true
    83  	})
    84  	if !v4Break {
    85  		c.v6.ForEach(func(prefix uint, k Key[netip.Prefix], v T) bool {
    86  			return fn(k.Value(), v)
    87  		})
    88  	}
    89  
    90  }
    91  
    92  func (c *CIDRTrie[T]) treeForFamily(cidr netip.Prefix) Trie[Key[netip.Prefix], T] {
    93  	if cidr.Addr().Is6() {
    94  		return c.v6
    95  	}
    96  	return c.v4
    97  }
    98  
    99  type cidrKey netip.Prefix
   100  
   101  func (k cidrKey) Value() netip.Prefix {
   102  	return netip.Prefix(k)
   103  }
   104  
   105  func (k cidrKey) BitValueAt(idx uint) uint8 {
   106  	addr := netip.Prefix(k).Addr()
   107  	if addr.Is4() {
   108  		word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[1]
   109  		return uint8((word >> (31 - idx)) & 1)
   110  	}
   111  	if idx < 64 {
   112  		word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[0]
   113  		return uint8((word >> (63 - idx)) & 1)
   114  	} else {
   115  		word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[1]
   116  		return uint8((word >> (127 - idx)) & 1)
   117  	}
   118  }
   119  
   120  func (k cidrKey) CommonPrefix(k2 netip.Prefix) uint {
   121  	addr1 := netip.Prefix(k).Addr()
   122  	addr2 := k2.Addr()
   123  	words1 := (*[2]uint64)(unsafe.Pointer(&addr1))
   124  	words2 := (*[2]uint64)(unsafe.Pointer(&addr2))
   125  	if addr1.Is4() {
   126  		word1 := uint32((*words1)[1])
   127  		word2 := uint32((*words2)[1])
   128  		return uint(bits.LeadingZeros32(word1 ^ word2))
   129  	}
   130  	v := bits.LeadingZeros64((*words1)[0] ^ (*words2)[0])
   131  	if v == 64 {
   132  		v += bits.LeadingZeros64((*words1)[1] ^ (*words2)[1])
   133  	}
   134  	return uint(v)
   135  }