github.com/slackhq/nebula@v1.9.0/cidr/tree4.go (about) 1 package cidr 2 3 import ( 4 "net" 5 6 "github.com/slackhq/nebula/iputil" 7 ) 8 9 type Node[T any] struct { 10 left *Node[T] 11 right *Node[T] 12 parent *Node[T] 13 hasValue bool 14 value T 15 } 16 17 type entry[T any] struct { 18 CIDR *net.IPNet 19 Value T 20 } 21 22 type Tree4[T any] struct { 23 root *Node[T] 24 list []entry[T] 25 } 26 27 const ( 28 startbit = iputil.VpnIp(0x80000000) 29 ) 30 31 func NewTree4[T any]() *Tree4[T] { 32 tree := new(Tree4[T]) 33 tree.root = &Node[T]{} 34 tree.list = []entry[T]{} 35 return tree 36 } 37 38 func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { 39 bit := startbit 40 node := tree.root 41 next := tree.root 42 43 ip := iputil.Ip2VpnIp(cidr.IP) 44 mask := iputil.Ip2VpnIp(cidr.Mask) 45 46 // Find our last ancestor in the tree 47 for bit&mask != 0 { 48 if ip&bit != 0 { 49 next = node.right 50 } else { 51 next = node.left 52 } 53 54 if next == nil { 55 break 56 } 57 58 bit = bit >> 1 59 node = next 60 } 61 62 // We already have this range so update the value 63 if next != nil { 64 addCIDR := cidr.String() 65 for i, v := range tree.list { 66 if addCIDR == v.CIDR.String() { 67 tree.list = append(tree.list[:i], tree.list[i+1:]...) 68 break 69 } 70 } 71 72 tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) 73 node.value = val 74 node.hasValue = true 75 return 76 } 77 78 // Build up the rest of the tree we don't already have 79 for bit&mask != 0 { 80 next = &Node[T]{} 81 next.parent = node 82 83 if ip&bit != 0 { 84 node.right = next 85 } else { 86 node.left = next 87 } 88 89 bit >>= 1 90 node = next 91 } 92 93 // Final node marks our cidr, set the value 94 node.value = val 95 node.hasValue = true 96 tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) 97 } 98 99 // Contains finds the first match, which may be the least specific 100 func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { 101 bit := startbit 102 node := tree.root 103 104 for node != nil { 105 if node.hasValue { 106 return true, node.value 107 } 108 109 if ip&bit != 0 { 110 node = node.right 111 } else { 112 node = node.left 113 } 114 115 bit >>= 1 116 117 } 118 119 return false, value 120 } 121 122 // MostSpecificContains finds the most specific match 123 func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { 124 bit := startbit 125 node := tree.root 126 127 for node != nil { 128 if node.hasValue { 129 value = node.value 130 ok = true 131 } 132 133 if ip&bit != 0 { 134 node = node.right 135 } else { 136 node = node.left 137 } 138 139 bit >>= 1 140 } 141 142 return ok, value 143 } 144 145 type eachFunc[T any] func(T) bool 146 147 // EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete 148 // The final return value will be true if the provided function returned true 149 func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { 150 bit := startbit 151 node := tree.root 152 153 for node != nil { 154 if node.hasValue { 155 // If the each func returns true then we can exit the loop 156 if each(node.value) { 157 return true 158 } 159 } 160 161 if ip&bit != 0 { 162 node = node.right 163 } else { 164 node = node.left 165 } 166 167 bit >>= 1 168 } 169 170 return false 171 } 172 173 // GetCIDR returns the entry added by the most recent matching AddCIDR call 174 func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { 175 bit := startbit 176 node := tree.root 177 178 ip := iputil.Ip2VpnIp(cidr.IP) 179 mask := iputil.Ip2VpnIp(cidr.Mask) 180 181 // Find our last ancestor in the tree 182 for node != nil && bit&mask != 0 { 183 if ip&bit != 0 { 184 node = node.right 185 } else { 186 node = node.left 187 } 188 189 bit = bit >> 1 190 } 191 192 if bit&mask == 0 && node != nil { 193 value = node.value 194 ok = node.hasValue 195 } 196 197 return ok, value 198 } 199 200 // List will return all CIDRs and their current values. Do not modify the contents! 201 func (tree *Tree4[T]) List() []entry[T] { 202 return tree.list 203 }