github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/device/allowedips.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "container/list" 10 "errors" 11 "math/bits" 12 "net" 13 "sync" 14 "unsafe" 15 ) 16 17 type parentIndirection struct { 18 parentBit **trieEntry 19 parentBitType uint8 20 } 21 22 type trieEntry struct { 23 peer *Peer 24 child [2]*trieEntry 25 parent parentIndirection 26 cidr uint8 27 bitAtByte uint8 28 bitAtShift uint8 29 bits net.IP 30 perPeerElem *list.Element 31 } 32 33 func isLittleEndian() bool { 34 one := uint32(1) 35 return *(*byte)(unsafe.Pointer(&one)) != 0 36 } 37 38 func swapU32(i uint32) uint32 { 39 if !isLittleEndian() { 40 return i 41 } 42 43 return bits.ReverseBytes32(i) 44 } 45 46 func swapU64(i uint64) uint64 { 47 if !isLittleEndian() { 48 return i 49 } 50 51 return bits.ReverseBytes64(i) 52 } 53 54 func commonBits(ip1 net.IP, ip2 net.IP) uint8 { 55 size := len(ip1) 56 if size == net.IPv4len { 57 a := (*uint32)(unsafe.Pointer(&ip1[0])) 58 b := (*uint32)(unsafe.Pointer(&ip2[0])) 59 x := *a ^ *b 60 return uint8(bits.LeadingZeros32(swapU32(x))) 61 } else if size == net.IPv6len { 62 a := (*uint64)(unsafe.Pointer(&ip1[0])) 63 b := (*uint64)(unsafe.Pointer(&ip2[0])) 64 x := *a ^ *b 65 if x != 0 { 66 return uint8(bits.LeadingZeros64(swapU64(x))) 67 } 68 a = (*uint64)(unsafe.Pointer(&ip1[8])) 69 b = (*uint64)(unsafe.Pointer(&ip2[8])) 70 x = *a ^ *b 71 return 64 + uint8(bits.LeadingZeros64(swapU64(x))) 72 } else { 73 panic("Wrong size bit string") 74 } 75 } 76 77 func (node *trieEntry) addToPeerEntries() { 78 node.perPeerElem = node.peer.trieEntries.PushBack(node) 79 } 80 81 func (node *trieEntry) removeFromPeerEntries() { 82 if node.perPeerElem != nil { 83 node.peer.trieEntries.Remove(node.perPeerElem) 84 node.perPeerElem = nil 85 } 86 } 87 88 func (node *trieEntry) choose(ip net.IP) byte { 89 return (ip[node.bitAtByte] >> node.bitAtShift) & 1 90 } 91 92 func (node *trieEntry) maskSelf() { 93 mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 94 for i := 0; i < len(mask); i++ { 95 node.bits[i] &= mask[i] 96 } 97 } 98 99 func (node *trieEntry) zeroizePointers() { 100 // Make the garbage collector's life slightly easier 101 node.peer = nil 102 node.child[0] = nil 103 node.child[1] = nil 104 node.parent.parentBit = nil 105 } 106 107 func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { 108 for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { 109 parent = node 110 if parent.cidr == cidr { 111 exact = true 112 return 113 } 114 bit := node.choose(ip) 115 node = node.child[bit] 116 } 117 return 118 } 119 120 func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { 121 if *trie.parentBit == nil { 122 node := &trieEntry{ 123 peer: peer, 124 parent: trie, 125 bits: ip, 126 cidr: cidr, 127 bitAtByte: cidr / 8, 128 bitAtShift: 7 - (cidr % 8), 129 } 130 node.maskSelf() 131 node.addToPeerEntries() 132 *trie.parentBit = node 133 return 134 } 135 node, exact := (*trie.parentBit).nodePlacement(ip, cidr) 136 if exact { 137 node.removeFromPeerEntries() 138 node.peer = peer 139 node.addToPeerEntries() 140 return 141 } 142 143 newNode := &trieEntry{ 144 peer: peer, 145 bits: ip, 146 cidr: cidr, 147 bitAtByte: cidr / 8, 148 bitAtShift: 7 - (cidr % 8), 149 } 150 newNode.maskSelf() 151 newNode.addToPeerEntries() 152 153 var down *trieEntry 154 if node == nil { 155 down = *trie.parentBit 156 } else { 157 bit := node.choose(ip) 158 down = node.child[bit] 159 if down == nil { 160 newNode.parent = parentIndirection{&node.child[bit], bit} 161 node.child[bit] = newNode 162 return 163 } 164 } 165 common := commonBits(down.bits, ip) 166 if common < cidr { 167 cidr = common 168 } 169 parent := node 170 171 if newNode.cidr == cidr { 172 bit := newNode.choose(down.bits) 173 down.parent = parentIndirection{&newNode.child[bit], bit} 174 newNode.child[bit] = down 175 if parent == nil { 176 newNode.parent = trie 177 *trie.parentBit = newNode 178 } else { 179 bit := parent.choose(newNode.bits) 180 newNode.parent = parentIndirection{&parent.child[bit], bit} 181 parent.child[bit] = newNode 182 } 183 return 184 } 185 186 node = &trieEntry{ 187 bits: append([]byte{}, newNode.bits...), 188 cidr: cidr, 189 bitAtByte: cidr / 8, 190 bitAtShift: 7 - (cidr % 8), 191 } 192 node.maskSelf() 193 194 bit := node.choose(down.bits) 195 down.parent = parentIndirection{&node.child[bit], bit} 196 node.child[bit] = down 197 bit = node.choose(newNode.bits) 198 newNode.parent = parentIndirection{&node.child[bit], bit} 199 node.child[bit] = newNode 200 if parent == nil { 201 node.parent = trie 202 *trie.parentBit = node 203 } else { 204 bit := parent.choose(node.bits) 205 node.parent = parentIndirection{&parent.child[bit], bit} 206 parent.child[bit] = node 207 } 208 } 209 210 func (node *trieEntry) lookup(ip net.IP) *Peer { 211 var found *Peer 212 size := uint8(len(ip)) 213 for node != nil && commonBits(node.bits, ip) >= node.cidr { 214 if node.peer != nil { 215 found = node.peer 216 } 217 if node.bitAtByte == size { 218 break 219 } 220 bit := node.choose(ip) 221 node = node.child[bit] 222 } 223 return found 224 } 225 226 type AllowedIPs struct { 227 IPv4 *trieEntry 228 IPv6 *trieEntry 229 mutex sync.RWMutex 230 } 231 232 func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { 233 table.mutex.RLock() 234 defer table.mutex.RUnlock() 235 236 for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 237 node := elem.Value.(*trieEntry) 238 if !cb(node.bits, node.cidr) { 239 return 240 } 241 } 242 } 243 244 func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 245 table.mutex.Lock() 246 defer table.mutex.Unlock() 247 248 var next *list.Element 249 for elem := peer.trieEntries.Front(); elem != nil; elem = next { 250 next = elem.Next() 251 node := elem.Value.(*trieEntry) 252 253 node.removeFromPeerEntries() 254 node.peer = nil 255 if node.child[0] != nil && node.child[1] != nil { 256 continue 257 } 258 bit := 0 259 if node.child[0] == nil { 260 bit = 1 261 } 262 child := node.child[bit] 263 if child != nil { 264 child.parent = node.parent 265 } 266 *node.parent.parentBit = child 267 if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { 268 node.zeroizePointers() 269 continue 270 } 271 parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) 272 if parent.peer != nil { 273 node.zeroizePointers() 274 continue 275 } 276 child = parent.child[node.parent.parentBitType^1] 277 if child != nil { 278 child.parent = parent.parent 279 } 280 *parent.parent.parentBit = child 281 node.zeroizePointers() 282 parent.zeroizePointers() 283 } 284 } 285 286 func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { 287 table.mutex.Lock() 288 defer table.mutex.Unlock() 289 290 switch len(ip) { 291 case net.IPv6len: 292 parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) 293 case net.IPv4len: 294 parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) 295 default: 296 panic(errors.New("inserting unknown address type")) 297 } 298 } 299 300 func (table *AllowedIPs) Lookup(address []byte) *Peer { 301 table.mutex.RLock() 302 defer table.mutex.RUnlock() 303 switch len(address) { 304 case net.IPv6len: 305 return table.IPv6.lookup(address) 306 case net.IPv4len: 307 return table.IPv4.lookup(address) 308 default: 309 panic(errors.New("looking up unknown address type")) 310 } 311 }