k8s.io/registry.k8s.io@v0.3.1/pkg/net/cidrs/triemap.go (about) 1 /* 2 Copyright 2022 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package cidrs 18 19 import ( 20 "net/netip" 21 ) 22 23 // TrieMap contains an efficient trie structure of netip.Prefix that can 24 // match a netip.Addr to the associated Prefix if any and return the value 25 // associated with it of type V. 26 // 27 // # Use NewTrieMap to instantiate 28 // 29 // NOTE: This is insert-only (no delete) and insertion is *not* thread-safe. 30 // 31 // Currently this is a simple TrieMap, in the future it may have compression. 
//
// See: https://vincent.bernat.ch/en/blog/2017-ipv4-route-lookup-linux
//
// For benchmarks with real data see ./aws/mapper_test.go
//
// Lookups walk the address bits from most to least significant and return the
// value of the first inserted prefix found on that path, i.e. the shortest
// (least specific) inserted prefix containing the address. NOTE(review): this
// is not longest-prefix-match routing semantics; if overlapping prefixes with
// different values are inserted, the broader prefix wins — confirm this is
// intended by callers.
//
// NewTrieMap is the conventional constructor, but the zero value is also
// ready to use.
type TrieMap[V comparable] struct {
	// trieMap is the real trie, but it only maps netip.Prefix / netip.Addr to
	// an int key.
	// see: https://planetscale.com/blog/generics-can-make-your-go-code-slower
	// The maps below translate those int keys to the generic value type V.
	//
	// This is also cheaper in many cases because an int will usually be
	// smaller than V: each distinct V is stored only once in the maps here,
	// and trie nodes store int keys into them. Many trie nodes are expected
	// to map to the same V, as our target use-case is CIDR-to-cloud-region.
	trieMap trieMap

	// keyToValue and valueToKey form a simple inline bimap between int keys
	// and V values. Keys are dense: 0, 1, 2, ... in order of first insertion.
	//
	// The inner trie stores int keys that index into keyToValue.
	//
	// valueToKey makes it cheap to check whether a given V was already
	// inserted, so the same key can be reused.
	keyToValue map[int]V
	valueToKey map[V]int
}

// NewTrieMap returns a new, empty TrieMap[V] with its internal maps allocated.
func NewTrieMap[V comparable]() *TrieMap[V] {
	return &TrieMap[V]{
		keyToValue: make(map[int]V),
		valueToKey: make(map[V]int),
	}
}

// Insert associates value with cidr. You can later match a netip.Addr to
// value with GetIP.
//
// Inserting a prefix that was already inserted replaces its value. Invalid
// prefixes (cidr.IsValid() == false) are ignored, since they cannot contain
// any address.
//
// Insert is not safe for concurrent use.
func (t *TrieMap[V]) Insert(cidr netip.Prefix, value V) {
	if !cidr.IsValid() {
		return
	}
	// Lazily allocate the bimap so the zero value of TrieMap is usable.
	if t.keyToValue == nil {
		t.keyToValue = make(map[int]V)
	}
	if t.valueToKey == nil {
		t.valueToKey = make(map[V]int)
	}
	key, alreadyHave := t.valueToKey[value]
	if !alreadyHave {
		// This structure is insert-only, so the next unused key is the
		// current number of distinct values.
		key = len(t.keyToValue)
		t.valueToKey[value] = key
		t.keyToValue[key] = value
	}
	t.trieMap.Insert(cidr, key)
}

// GetIP returns the value associated with the inserted prefix matching ip,
// with contains=true, or else the zero value of V and contains=false.
//
// If several inserted prefixes contain ip, the shortest (least specific) one
// wins.
func (t *TrieMap[V]) GetIP(ip netip.Addr) (value V, contains bool) {
	key, ok := t.trieMap.GetIP(ip)
	if !ok {
		return value, false
	}
	return t.keyToValue[key], true
}

// trieMap is the core implementation, but it only stores netip.Prefix : int.
// IPv4 prefixes live under ipv4Root; everything else (including IPv4-mapped
// IPv6 prefixes, for which Addr.Is4 is false) lives under ipv6Root. A nil
// root means nothing has been inserted for that family yet.
type trieMap struct {
	// surely ipv4 and ipv6 will be enough in our lifetime?
	ipv4Root *trieNode
	ipv6Root *trieNode
}

// trieNode is one node of an uncompressed binary trie: a node at depth d
// represents the first d bits of an address, i.e. a /d prefix.
//
// TODO: path compression
type trieNode struct {
	// children for 0 and 1 bits
	child0 *trieNode
	child1 *trieNode
	// value is non-nil only if a prefix ending at this node was inserted.
	// The cidr and key are always set together, so they live in a sub struct
	// to save memory on nodes without a value, at the cost of chasing one
	// additional pointer per trie node checked.
	value *nodeValue
}

// nodeValue is the payload of a trieNode: the inserted prefix and the key
// the outer TrieMap uses to look up its value.
type nodeValue struct {
	cidr netip.Prefix
	key  int
}

// ipv4ToUint32 packs a big-endian IPv4 address into a uint32 so that bit i
// (31 = most significant) can be tested with a shift and mask.
func ipv4ToUint32(ip [4]byte) uint32 {
	return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

// ipv6ToUint64s packs a big-endian IPv6 address into its high and low 64-bit
// halves so that individual bits can be tested with ipv6Bit.
func ipv6ToUint64s(ip [16]byte) (hi, lo uint64) {
	hi = uint64(ip[0])<<56 | uint64(ip[1])<<48 | uint64(ip[2])<<40 | uint64(ip[3])<<32 |
		uint64(ip[4])<<24 | uint64(ip[5])<<16 | uint64(ip[6])<<8 | uint64(ip[7])
	lo = uint64(ip[8])<<56 | uint64(ip[9])<<48 | uint64(ip[10])<<40 | uint64(ip[11])<<32 |
		uint64(ip[12])<<24 | uint64(ip[13])<<16 | uint64(ip[14])<<8 | uint64(ip[15])
	return hi, lo
}

// ipv6Bit reports whether bit i (127 = most significant, 0 = least) of the
// address with halves hi and lo is set. i must be in [0, 127].
func ipv6Bit(hi, lo uint64, i int) bool {
	if i > 63 {
		return hi&(uint64(1)<<(i-64)) != 0
	}
	return lo&(uint64(1)<<i) != 0
}

// Insert stores key for cidr in the trie for cidr's address family.
// Invalid prefixes are ignored.
func (t *trieMap) Insert(cidr netip.Prefix, key int) {
	if !cidr.IsValid() {
		// An invalid prefix reports Bits() == -1, which would otherwise store
		// it at (and overwrite) the root node, clobbering any /0 entry, even
		// though it can never contain an address.
		return
	}
	if cidr.Addr().Is4() {
		t.insertIPV4(cidr, key)
	} else {
		t.insertIPV6(cidr, key)
	}
}

// insertIPV4 walks (creating as needed) the path of cidr's first Bits() bits
// from high to low in the IPv4 trie and stores key at the final node.
func (t *trieMap) insertIPV4(cidr netip.Prefix, key int) {
	if t.ipv4Root == nil {
		t.ipv4Root = &trieNode{}
	}
	curr := t.ipv4Root
	ipInt := ipv4ToUint32(cidr.Addr().As4())
	stop := 32 - cidr.Bits()
	for i := 31; i >= stop; i-- {
		next := &curr.child0
		if ipInt&(uint32(1)<<i) != 0 {
			next = &curr.child1
		}
		if *next == nil {
			*next = &trieNode{}
		}
		curr = *next
	}
	curr.value = &nodeValue{cidr: cidr, key: key}
}

// insertIPV6 walks (creating as needed) the path of cidr's first Bits() bits
// from high to low in the IPv6 trie and stores key at the final node.
func (t *trieMap) insertIPV6(cidr netip.Prefix, key int) {
	if t.ipv6Root == nil {
		t.ipv6Root = &trieNode{}
	}
	curr := t.ipv6Root
	hi, lo := ipv6ToUint64s(cidr.Addr().As16())
	stop := 128 - cidr.Bits()
	for i := 127; i >= stop; i-- {
		next := &curr.child0
		if ipv6Bit(hi, lo, i) {
			next = &curr.child1
		}
		if *next == nil {
			*next = &trieNode{}
		}
		curr = *next
	}
	curr.value = &nodeValue{cidr: cidr, key: key}
}

// GetIP returns the key stored for the shortest inserted prefix containing
// ip and true, or -1 and false when no inserted prefix contains it.
func (t *trieMap) GetIP(ip netip.Addr) (int, bool) {
	if ip.Is4() {
		return t.getIPv4(ip)
	}
	return t.getIPv6(ip)
}

// getIPv4 checks every node on addr's bit path, starting at the root (the
// /0 prefix), and returns the key of the first stored prefix containing addr.
func (t *trieMap) getIPv4(addr netip.Addr) (int, bool) {
	ipInt := ipv4ToUint32(addr.As4())
	curr := t.ipv4Root
	// The loop ends early at a dead end (curr == nil) or after the /32 node.
	for i := 31; curr != nil; i-- {
		// Contains is still checked even though the path matched: it also
		// rejects zoned addresses.
		if v := curr.value; v != nil && v.cidr.Contains(addr) {
			return v.key, true
		}
		if i < 0 {
			break
		}
		if ipInt&(uint32(1)<<i) != 0 {
			curr = curr.child1
		} else {
			curr = curr.child0
		}
	}
	return -1, false
}

// getIPv6 checks every node on addr's bit path, starting at the root (the
// /0 prefix), and returns the key of the first stored prefix containing addr.
// The zero netip.Addr is also routed here and never matches.
func (t *trieMap) getIPv6(addr netip.Addr) (int, bool) {
	hi, lo := ipv6ToUint64s(addr.As16())
	curr := t.ipv6Root
	// The loop ends early at a dead end (curr == nil) or after the /128 node.
	for i := 127; curr != nil; i-- {
		// Contains is still checked even though the path matched: it also
		// rejects zoned IPv6 addresses and the zero Addr.
		if v := curr.value; v != nil && v.cidr.Contains(addr) {
			return v.key, true
		}
		if i < 0 {
			break
		}
		if ipv6Bit(hi, lo, i) {
			curr = curr.child1
		} else {
			curr = curr.child0
		}
	}
	return -1, false
}