github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/allowedips.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "container/list" 10 "encoding/binary" 11 "errors" 12 "math/bits" 13 "net" 14 "net/netip" 15 "sync" 16 "unsafe" 17 ) 18 19 type parentIndirection struct { 20 parentBit **trieEntry 21 parentBitType uint8 22 } 23 24 type trieEntry struct { 25 peer *Peer 26 child [2]*trieEntry 27 parent parentIndirection 28 cidr uint8 29 bitAtByte uint8 30 bitAtShift uint8 31 bits []byte 32 perPeerElem *list.Element 33 } 34 35 func commonBits(ip1, ip2 []byte) uint8 { 36 size := len(ip1) 37 if size == net.IPv4len { 38 a := binary.BigEndian.Uint32(ip1) 39 b := binary.BigEndian.Uint32(ip2) 40 x := a ^ b 41 return uint8(bits.LeadingZeros32(x)) 42 } else if size == net.IPv6len { 43 a := binary.BigEndian.Uint64(ip1) 44 b := binary.BigEndian.Uint64(ip2) 45 x := a ^ b 46 if x != 0 { 47 return uint8(bits.LeadingZeros64(x)) 48 } 49 a = binary.BigEndian.Uint64(ip1[8:]) 50 b = binary.BigEndian.Uint64(ip2[8:]) 51 x = a ^ b 52 return 64 + uint8(bits.LeadingZeros64(x)) 53 } else { 54 panic("Wrong size bit string") 55 } 56 } 57 58 func (node *trieEntry) addToPeerEntries() { 59 node.perPeerElem = node.peer.trieEntries.PushBack(node) 60 } 61 62 func (node *trieEntry) removeFromPeerEntries() { 63 if node.perPeerElem != nil { 64 node.peer.trieEntries.Remove(node.perPeerElem) 65 node.perPeerElem = nil 66 } 67 } 68 69 func (node *trieEntry) choose(ip []byte) byte { 70 return (ip[node.bitAtByte] >> node.bitAtShift) & 1 71 } 72 73 func (node *trieEntry) maskSelf() { 74 mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 75 for i := 0; i < len(mask); i++ { 76 node.bits[i] &= mask[i] 77 } 78 } 79 80 func (node *trieEntry) zeroizePointers() { 81 // Make the garbage collector's life slightly easier 82 node.peer = nil 83 node.child[0] = nil 84 node.child[1] = nil 85 node.parent.parentBit = nil 86 } 87 88 func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { 89 for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { 90 parent = node 91 if parent.cidr == cidr { 92 exact = true 93 return 94 } 95 bit := node.choose(ip) 96 node = node.child[bit] 97 } 98 return 99 } 100 101 func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { 102 if *trie.parentBit == nil { 103 node := &trieEntry{ 104 peer: peer, 105 parent: trie, 106 bits: ip, 107 cidr: cidr, 108 bitAtByte: cidr / 8, 109 bitAtShift: 7 - (cidr % 8), 110 } 111 node.maskSelf() 112 node.addToPeerEntries() 113 *trie.parentBit = node 114 return 115 } 116 node, exact := (*trie.parentBit).nodePlacement(ip, cidr) 117 if exact { 118 node.removeFromPeerEntries() 119 node.peer = peer 120 node.addToPeerEntries() 121 return 122 } 123 124 newNode := &trieEntry{ 125 peer: peer, 126 bits: ip, 127 cidr: cidr, 128 bitAtByte: cidr / 8, 129 bitAtShift: 7 - (cidr % 8), 130 } 131 newNode.maskSelf() 132 newNode.addToPeerEntries() 133 134 var down *trieEntry 135 if node == nil { 136 down = *trie.parentBit 137 } else { 138 bit := node.choose(ip) 139 down = node.child[bit] 140 if down == nil { 141 newNode.parent = parentIndirection{&node.child[bit], bit} 142 node.child[bit] = newNode 143 return 144 } 145 } 146 common := commonBits(down.bits, ip) 147 if common < cidr { 148 cidr = common 149 } 150 parent := node 151 152 if newNode.cidr == cidr { 153 bit := newNode.choose(down.bits) 154 down.parent = parentIndirection{&newNode.child[bit], bit} 155 newNode.child[bit] = down 156 if parent == nil { 157 newNode.parent = trie 158 *trie.parentBit = newNode 159 } else { 160 bit := parent.choose(newNode.bits) 161 newNode.parent = parentIndirection{&parent.child[bit], bit} 162 parent.child[bit] = newNode 163 } 164 return 165 } 166 167 node = &trieEntry{ 168 bits: append([]byte{}, newNode.bits...), 169 cidr: cidr, 170 bitAtByte: cidr / 8, 171 bitAtShift: 7 - (cidr % 8), 172 } 173 node.maskSelf() 174 175 bit := node.choose(down.bits) 176 down.parent = parentIndirection{&node.child[bit], bit} 177 node.child[bit] = down 178 bit = node.choose(newNode.bits) 179 newNode.parent = parentIndirection{&node.child[bit], bit} 180 node.child[bit] = newNode 181 if parent == nil { 182 node.parent = trie 183 *trie.parentBit = node 184 } else { 185 bit := parent.choose(node.bits) 186 node.parent = parentIndirection{&parent.child[bit], bit} 187 parent.child[bit] = node 188 } 189 } 190 191 func (node *trieEntry) lookup(ip []byte) *Peer { 192 var found *Peer 193 size := uint8(len(ip)) 194 for node != nil && commonBits(node.bits, ip) >= node.cidr { 195 if node.peer != nil { 196 found = node.peer 197 } 198 if node.bitAtByte == size { 199 break 200 } 201 bit := node.choose(ip) 202 node = node.child[bit] 203 } 204 return found 205 } 206 207 type AllowedIPs struct { 208 IPv4 *trieEntry 209 IPv6 *trieEntry 210 mutex sync.RWMutex 211 } 212 213 func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { 214 table.mutex.RLock() 215 defer table.mutex.RUnlock() 216 217 for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 218 node := elem.Value.(*trieEntry) 219 a, _ := netip.AddrFromSlice(node.bits) 220 if !cb(netip.PrefixFrom(a, int(node.cidr))) { 221 return 222 } 223 } 224 } 225 226 func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 227 table.mutex.Lock() 228 defer table.mutex.Unlock() 229 230 var next *list.Element 231 for elem := peer.trieEntries.Front(); elem != nil; elem = next { 232 next = elem.Next() 233 node := elem.Value.(*trieEntry) 234 235 node.removeFromPeerEntries() 236 node.peer = nil 237 if node.child[0] != nil && node.child[1] != nil { 238 continue 239 } 240 bit := 0 241 if node.child[0] == nil { 242 bit = 1 243 } 244 child := node.child[bit] 245 if child != nil { 246 child.parent = node.parent 247 } 248 *node.parent.parentBit = child 249 if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { 250 node.zeroizePointers() 251 continue 252 } 253 parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) 254 if parent.peer != nil { 255 node.zeroizePointers() 256 continue 257 } 258 child = parent.child[node.parent.parentBitType^1] 259 if child != nil { 260 child.parent = parent.parent 261 } 262 *parent.parent.parentBit = child 263 node.zeroizePointers() 264 parent.zeroizePointers() 265 } 266 } 267 268 func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { 269 table.mutex.Lock() 270 defer table.mutex.Unlock() 271 272 if prefix.Addr().Is6() { 273 ip := prefix.Addr().As16() 274 parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 275 } else if prefix.Addr().Is4() { 276 ip := prefix.Addr().As4() 277 parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 278 } else { 279 panic(errors.New("inserting unknown address type")) 280 } 281 } 282 283 func (table *AllowedIPs) Lookup(ip []byte) *Peer { 284 table.mutex.RLock() 285 defer table.mutex.RUnlock() 286 switch len(ip) { 287 case net.IPv6len: 288 return table.IPv6.lookup(ip) 289 case net.IPv4len: 290 return table.IPv4.lookup(ip) 291 default: 292 panic(errors.New("looking up unknown address type")) 293 } 294 }