github.com/cilium/cilium@v1.16.2/pkg/container/bitlpm/cidr.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package bitlpm 5 6 import ( 7 "math/bits" 8 "net/netip" 9 "unsafe" 10 ) 11 12 // CIDRTrie can hold both IPv4 and IPv6 prefixes 13 // at the same time. 14 type CIDRTrie[T any] struct { 15 v4 Trie[Key[netip.Prefix], T] 16 v6 Trie[Key[netip.Prefix], T] 17 } 18 19 // NewCIDRTrie creates a new CIDRTrie[T any]. 20 func NewCIDRTrie[T any]() *CIDRTrie[T] { 21 return &CIDRTrie[T]{ 22 v4: NewTrie[netip.Prefix, T](32), 23 v6: NewTrie[netip.Prefix, T](128), 24 } 25 } 26 27 // ExactLookup returns the value for a given CIDR, but only 28 // if there is an exact match for the CIDR in the Trie. 29 func (c *CIDRTrie[T]) ExactLookup(cidr netip.Prefix) (T, bool) { 30 return c.treeForFamily(cidr).ExactLookup(uint(cidr.Bits()), cidrKey(cidr)) 31 } 32 33 // LongestPrefixMatch returns the longest matched value for a given address. 34 func (c *CIDRTrie[T]) LongestPrefixMatch(addr netip.Addr) (T, bool) { 35 if !addr.IsValid() { 36 var def T 37 return def, false 38 } 39 bits := addr.BitLen() 40 prefix := netip.PrefixFrom(addr, bits) 41 return c.treeForFamily(prefix).LongestPrefixMatch(cidrKey(prefix)) 42 } 43 44 // Ancestors iterates over every CIDR pair that contains the CIDR argument. 45 func (c *CIDRTrie[T]) Ancestors(cidr netip.Prefix, fn func(k netip.Prefix, v T) bool) { 46 c.treeForFamily(cidr).Ancestors(uint(cidr.Bits()), cidrKey(cidr), func(prefix uint, k Key[netip.Prefix], v T) bool { 47 return fn(k.Value(), v) 48 }) 49 } 50 51 // Descendants iterates over every CIDR that is contained by the CIDR argument. 52 func (c *CIDRTrie[T]) Descendants(cidr netip.Prefix, fn func(k netip.Prefix, v T) bool) { 53 c.treeForFamily(cidr).Descendants(uint(cidr.Bits()), cidrKey(cidr), func(prefix uint, k Key[netip.Prefix], v T) bool { 54 return fn(k.Value(), v) 55 }) 56 } 57 58 // Upsert adds or updates the value for a given prefix. 59 func (c *CIDRTrie[T]) Upsert(cidr netip.Prefix, v T) { 60 c.treeForFamily(cidr).Upsert(uint(cidr.Bits()), cidrKey(cidr), v) 61 } 62 63 // Delete removes a given prefix from the tree. 64 func (c *CIDRTrie[T]) Delete(cidr netip.Prefix) bool { 65 return c.treeForFamily(cidr).Delete(uint(cidr.Bits()), cidrKey(cidr)) 66 } 67 68 // Len returns the total number of ipv4 and ipv6 prefixes in the trie. 69 func (c *CIDRTrie[T]) Len() uint { 70 return c.v4.Len() + c.v6.Len() 71 } 72 73 // ForEach iterates over every element of the Trie. It iterates over IPv4 74 // keys first. 75 func (c *CIDRTrie[T]) ForEach(fn func(k netip.Prefix, v T) bool) { 76 var v4Break bool 77 c.v4.ForEach(func(prefix uint, k Key[netip.Prefix], v T) bool { 78 if !fn(k.Value(), v) { 79 v4Break = true 80 return false 81 } 82 return true 83 }) 84 if !v4Break { 85 c.v6.ForEach(func(prefix uint, k Key[netip.Prefix], v T) bool { 86 return fn(k.Value(), v) 87 }) 88 } 89 90 } 91 92 func (c *CIDRTrie[T]) treeForFamily(cidr netip.Prefix) Trie[Key[netip.Prefix], T] { 93 if cidr.Addr().Is6() { 94 return c.v6 95 } 96 return c.v4 97 } 98 99 type cidrKey netip.Prefix 100 101 func (k cidrKey) Value() netip.Prefix { 102 return netip.Prefix(k) 103 } 104 105 func (k cidrKey) BitValueAt(idx uint) uint8 { 106 addr := netip.Prefix(k).Addr() 107 if addr.Is4() { 108 word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[1] 109 return uint8((word >> (31 - idx)) & 1) 110 } 111 if idx < 64 { 112 word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[0] 113 return uint8((word >> (63 - idx)) & 1) 114 } else { 115 word := (*(*[2]uint64)(unsafe.Pointer(&addr)))[1] 116 return uint8((word >> (127 - idx)) & 1) 117 } 118 } 119 120 func (k cidrKey) CommonPrefix(k2 netip.Prefix) uint { 121 addr1 := netip.Prefix(k).Addr() 122 addr2 := k2.Addr() 123 words1 := (*[2]uint64)(unsafe.Pointer(&addr1)) 124 words2 := (*[2]uint64)(unsafe.Pointer(&addr2)) 125 if addr1.Is4() { 126 word1 := uint32((*words1)[1]) 127 word2 := uint32((*words2)[1]) 128 return uint(bits.LeadingZeros32(word1 ^ word2)) 129 } 130 v := bits.LeadingZeros64((*words1)[0] ^ (*words2)[0]) 131 if v == 64 { 132 v += bits.LeadingZeros64((*words1)[1] ^ (*words2)[1]) 133 } 134 return uint(v) 135 }