k8s.io/registry.k8s.io@v0.3.1/pkg/net/cidrs/triemap.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package cidrs
    18  
    19  import (
    20  	"net/netip"
    21  )
    22  
// TrieMap contains an efficient trie structure of netip.Prefix that can
// match a netip.Addr to the associated Prefix if any and return the value
// associated with it of type V.
//
// Use NewTrieMap to instantiate a TrieMap: the zero value is not usable,
// because Insert writes to internal maps that only NewTrieMap allocates.
//
// NOTE: This is insert-only (no delete) and insertion is *not* thread-safe.
//
// Currently this is a simple TrieMap, in the future it may have compression.
//
// See: https://vincent.bernat.ch/en/blog/2017-ipv4-route-lookup-linux
//
// For benchmarks with real data see ./aws/mapper_test.go
type TrieMap[V comparable] struct {
	// This is the real trie map, but it only maps netip.Prefix / netip.Addr : int.
	// See: https://planetscale.com/blog/generics-can-make-your-go-code-slower
	// The maps below map from the int stored in this trie to the generic
	// value type V.
	//
	// This is also cheaper in many cases because an int will often be
	// smaller than V: each distinct V is stored only once in the maps here,
	// and the trie nodes hold int indexes into those maps. That matters
	// because many trie nodes are expected to map to the same V, as our
	// target use-case is CIDR-to-cloud-region.
	trieMap trieMap

	// Simple inline bimap between int keys and V values.
	//
	// keyToValue: the inner trie stores an int key that indexes into it.
	//
	// valueToKey: makes it cheap to check whether a given V was already
	// inserted, so the same key can be reused for it.
	keyToValue map[int]V
	valueToKey map[V]int
}
    56  
    57  // NewTrieMap[V] returns a new, properly allocated TrieMap[V]
    58  func NewTrieMap[V comparable]() *TrieMap[V] {
    59  	return &TrieMap[V]{
    60  		keyToValue: make(map[int]V),
    61  		valueToKey: make(map[V]int),
    62  	}
    63  }
    64  
    65  // Insert inserts value into TrieMap by index cidr
    66  // You can later match a netip.Addr to value with GetIP
    67  func (t *TrieMap[V]) Insert(cidr netip.Prefix, value V) {
    68  	key, alreadyHave := t.valueToKey[value]
    69  	if !alreadyHave {
    70  		// next key = length of map
    71  		// this structure is insert-only
    72  		key = len(t.keyToValue)
    73  		t.valueToKey[value] = key
    74  		t.keyToValue[key] = value
    75  	}
    76  	t.trieMap.Insert(cidr, key)
    77  }
    78  
    79  // GetIP returns the associated value for the matching cidr if any with contains=true,
    80  // or else the default value of V and contains=false
    81  func (t *TrieMap[V]) GetIP(ip netip.Addr) (value V, contains bool) {
    82  	// NOTE: this is written so as not to shadow contains locally
    83  	// and so we can use value as a default-value for V without
    84  	// another variable, using the name also to document the return
    85  	key, c := t.trieMap.GetIP(ip)
    86  	contains = c
    87  	if !contains {
    88  		return
    89  	}
    90  	value = t.keyToValue[key]
    91  	return
    92  }
    93  
// trieMap is the core implementation, but it only stores netip.Prefix : int.
//
// IPv4 and IPv6 prefixes live in separate binary tries. Each root stays nil
// until an insert into that address family first needs it.
type trieMap struct {
	// surely ipv4 and ipv6 will be enough in our lifetime?
	// ipv4Root holds prefixes whose address is IPv4 (Addr().Is4());
	// ipv6Root holds every other prefix, including IPv4-mapped IPv6 ones.
	ipv4Root *trieNode
	ipv6Root *trieNode
}
   100  
// trieNode is one node of an uncompressed binary trie that consumes one
// address bit per level, most significant bit first. A node at depth d
// corresponds to the first d bits of an address. For valid prefixes, value
// is non-nil only on nodes where an inserted prefix of exactly d bits ends.
//
// TODO: path compression
type trieNode struct {
	// children for 0 and 1 bits
	child0 *trieNode
	child1 *trieNode
	// both of these values will be set together or not set
	// so we place them in a sub struct to save memory at the cost
	// of chasing one additional pointer per trie node checked
	value *nodeValue
}
   111  
// nodeValue is the payload of a trieNode on which an inserted prefix ends.
type nodeValue struct {
	// cidr is the prefix exactly as it was inserted; lookups re-check it
	// with Contains before reporting a match.
	cidr netip.Prefix
	// key is the int handle stored for cidr (TrieMap uses it to index its
	// keyToValue map).
	key int
}
   116  
   117  func (t *trieMap) Insert(cidr netip.Prefix, key int) {
   118  	if cidr.Addr().Is4() {
   119  		t.insertIPV4(cidr, key)
   120  	} else {
   121  		t.insertIPV6(cidr, key)
   122  	}
   123  }
   124  
   125  func (t *trieMap) insertIPV4(cidr netip.Prefix, key int) {
   126  	// ensure root node
   127  	if t.ipv4Root == nil {
   128  		t.ipv4Root = &trieNode{}
   129  	}
   130  
   131  	// walk bits high to low, inserting matching ip path up to mask bits
   132  	curr := t.ipv4Root
   133  	ip := cidr.Addr().As4()
   134  	// first cast to uint32 for fast bit access
   135  	// NOTE: IP addresses are big endian, so the low bits are in the last byte
   136  	ipInt := uint32(ip[3]) | uint32(ip[2])<<8 | uint32(ip[1])<<16 | uint32(ip[0])<<24
   137  	bits := cidr.Bits()
   138  	for i := 31; i >= (32 - bits); i-- {
   139  		if (ipInt & (uint32(1) << i)) != 0 {
   140  			if curr.child1 == nil {
   141  				curr.child1 = &trieNode{}
   142  			}
   143  			curr = curr.child1
   144  		} else {
   145  			if curr.child0 == nil {
   146  				curr.child0 = &trieNode{}
   147  			}
   148  			curr = curr.child0
   149  		}
   150  	}
   151  	curr.value = &nodeValue{
   152  		cidr: cidr,
   153  		key:  key,
   154  	}
   155  }
   156  
   157  func (t *trieMap) insertIPV6(cidr netip.Prefix, key int) {
   158  	// ensure root node
   159  	if t.ipv6Root == nil {
   160  		t.ipv6Root = &trieNode{}
   161  	}
   162  
   163  	// walk bits high to low, inserting matching ip path up to mask bits
   164  	curr := t.ipv6Root
   165  	ip := cidr.Addr().As16()
   166  	bits := cidr.Bits()
   167  	// first cast ip to two uint64 for fast bit access
   168  	// NOTE: IP addresses are big endian, so the low bits are in the last byte
   169  	ipLo := uint64(ip[15]) | uint64(ip[14])<<8 | uint64(ip[13])<<16 | uint64(ip[12])<<24 |
   170  		uint64(ip[11])<<32 | uint64(ip[10])<<40 | uint64(ip[9])<<48 | uint64(ip[8])<<56
   171  	ipHi := uint64(ip[7]) | uint64(ip[6])<<8 | uint64(ip[5])<<16 | uint64(ip[4])<<24 |
   172  		uint64(ip[3])<<32 | uint64(ip[2])<<40 | uint64(ip[1])<<48 | uint64(ip[0])<<56
   173  	for i := 127; i >= (128 - bits); i-- {
   174  		bit := false
   175  		if i > 63 {
   176  			bit = (ipHi & (uint64(1) << (i - 64))) != 0
   177  		} else {
   178  			bit = (ipLo & (uint64(1) << i)) != 0
   179  		}
   180  		if bit {
   181  			if curr.child1 == nil {
   182  				curr.child1 = &trieNode{}
   183  			}
   184  			curr = curr.child1
   185  		} else {
   186  			if curr.child0 == nil {
   187  				curr.child0 = &trieNode{}
   188  			}
   189  			curr = curr.child0
   190  		}
   191  	}
   192  	curr.value = &nodeValue{
   193  		cidr: cidr,
   194  		key:  key,
   195  	}
   196  }
   197  
   198  func (t *trieMap) GetIP(ip netip.Addr) (int, bool) {
   199  	if ip.Is4() {
   200  		return t.getIPv4(ip)
   201  	}
   202  	return t.getIPv6(ip)
   203  }
   204  
   205  func (t *trieMap) getIPv4(addr netip.Addr) (int, bool) {
   206  	// check the root first
   207  	curr := t.ipv4Root
   208  	if curr == nil {
   209  		return -1, false
   210  	}
   211  	if curr.value != nil && curr.value.cidr.Contains(addr) {
   212  		return curr.value.key, true
   213  	}
   214  	// walk IP bits high to low, checking if current node matches
   215  	ip := addr.As4()
   216  	// first cast to uint32 for fast bit access
   217  	// NOTE: IP addresses are big endian, so the low bits are in the last byte
   218  	ipInt := uint32(ip[3]) | uint32(ip[2])<<8 | uint32(ip[1])<<16 | uint32(ip[0])<<24
   219  	for i := 31; i >= 0; i-- {
   220  		// walk based on current address bit
   221  		if (ipInt & (uint32(1) << i)) != 0 {
   222  			if curr.child1 != nil {
   223  				curr = curr.child1
   224  			} else {
   225  				// dead end
   226  				break
   227  			}
   228  		} else {
   229  			if curr.child0 != nil {
   230  				curr = curr.child0
   231  			} else {
   232  				// dead end
   233  				break
   234  			}
   235  		}
   236  		// check for a match in the current node
   237  		if curr.value != nil && curr.value.cidr.Contains(addr) {
   238  			return curr.value.key, true
   239  		}
   240  	}
   241  	return -1, false
   242  }
   243  
   244  func (t *trieMap) getIPv6(addr netip.Addr) (int, bool) {
   245  	// check the root first
   246  	curr := t.ipv6Root
   247  	if curr == nil {
   248  		return -1, false
   249  	}
   250  	if curr.value != nil && curr.value.cidr.Contains(addr) {
   251  		return curr.value.key, true
   252  	}
   253  	// walk IP bits high to low, checking if current node matches
   254  	// first cast ip to two uint64 for fast bit access
   255  	ip := addr.As16()
   256  	// NOTE: IP addresses are big endian, so the low bits are in the last byte
   257  	ipLo := uint64(ip[15]) | uint64(ip[14])<<8 | uint64(ip[13])<<16 | uint64(ip[12])<<24 |
   258  		uint64(ip[11])<<32 | uint64(ip[10])<<40 | uint64(ip[9])<<48 | uint64(ip[8])<<56
   259  	ipHi := uint64(ip[7]) | uint64(ip[6])<<8 | uint64(ip[5])<<16 | uint64(ip[4])<<24 |
   260  		uint64(ip[3])<<32 | uint64(ip[2])<<40 | uint64(ip[1])<<48 | uint64(ip[0])<<56
   261  	for i := 127; i >= 0; i-- {
   262  		bit := false
   263  		if i > 63 {
   264  			bit = (ipHi & (uint64(1) << (i - 64))) != 0
   265  		} else {
   266  			bit = (ipLo & (uint64(1) << i)) != 0
   267  		}
   268  		// walk based on current address bit
   269  		if bit {
   270  			if curr.child1 != nil {
   271  				curr = curr.child1
   272  			} else {
   273  				// dead end
   274  				break
   275  			}
   276  		} else {
   277  			if curr.child0 != nil {
   278  				curr = curr.child0
   279  			} else {
   280  				// dead end
   281  				break
   282  			}
   283  		}
   284  		// check for a match in the current node
   285  		if curr.value != nil && curr.value.cidr.Contains(addr) {
   286  			return curr.value.key, true
   287  		}
   288  	}
   289  	return -1, false
   290  }