github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/allowedips.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "container/list" 10 "encoding/binary" 11 "errors" 12 "math/bits" 13 "net" 14 "sync" 15 "unsafe" 16 17 "golang.zx2c4.com/go118/netip" 18 ) 19 20 type parentIndirection struct { 21 parentBit **trieEntry 22 parentBitType uint8 23 } 24 25 type trieEntry struct { 26 peer *Peer 27 child [2]*trieEntry 28 parent parentIndirection 29 cidr uint8 30 bitAtByte uint8 31 bitAtShift uint8 32 bits []byte 33 perPeerElem *list.Element 34 } 35 36 func commonBits(ip1, ip2 []byte) uint8 { 37 size := len(ip1) 38 if size == net.IPv4len { 39 a := binary.BigEndian.Uint32(ip1) 40 b := binary.BigEndian.Uint32(ip2) 41 x := a ^ b 42 return uint8(bits.LeadingZeros32(x)) 43 } else if size == net.IPv6len { 44 a := binary.BigEndian.Uint64(ip1) 45 b := binary.BigEndian.Uint64(ip2) 46 x := a ^ b 47 if x != 0 { 48 return uint8(bits.LeadingZeros64(x)) 49 } 50 a = binary.BigEndian.Uint64(ip1[8:]) 51 b = binary.BigEndian.Uint64(ip2[8:]) 52 x = a ^ b 53 return 64 + uint8(bits.LeadingZeros64(x)) 54 } else { 55 panic("Wrong size bit string") 56 } 57 } 58 59 func (node *trieEntry) addToPeerEntries() { 60 node.perPeerElem = node.peer.trieEntries.PushBack(node) 61 } 62 63 func (node *trieEntry) removeFromPeerEntries() { 64 if node.perPeerElem != nil { 65 node.peer.trieEntries.Remove(node.perPeerElem) 66 node.perPeerElem = nil 67 } 68 } 69 70 func (node *trieEntry) choose(ip []byte) byte { 71 return (ip[node.bitAtByte] >> node.bitAtShift) & 1 72 } 73 74 func (node *trieEntry) maskSelf() { 75 mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 76 for i := 0; i < len(mask); i++ { 77 node.bits[i] &= mask[i] 78 } 79 } 80 81 func (node *trieEntry) zeroizePointers() { 82 // Make the garbage collector's life slightly easier 83 node.peer = nil 84 node.child[0] = nil 85 node.child[1] = nil 86 node.parent.parentBit = nil 87 } 88 89 func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { 90 for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { 91 parent = node 92 if parent.cidr == cidr { 93 exact = true 94 return 95 } 96 bit := node.choose(ip) 97 node = node.child[bit] 98 } 99 return 100 } 101 102 func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { 103 if *trie.parentBit == nil { 104 node := &trieEntry{ 105 peer: peer, 106 parent: trie, 107 bits: ip, 108 cidr: cidr, 109 bitAtByte: cidr / 8, 110 bitAtShift: 7 - (cidr % 8), 111 } 112 node.maskSelf() 113 node.addToPeerEntries() 114 *trie.parentBit = node 115 return 116 } 117 node, exact := (*trie.parentBit).nodePlacement(ip, cidr) 118 if exact { 119 node.removeFromPeerEntries() 120 node.peer = peer 121 node.addToPeerEntries() 122 return 123 } 124 125 newNode := &trieEntry{ 126 peer: peer, 127 bits: ip, 128 cidr: cidr, 129 bitAtByte: cidr / 8, 130 bitAtShift: 7 - (cidr % 8), 131 } 132 newNode.maskSelf() 133 newNode.addToPeerEntries() 134 135 var down *trieEntry 136 if node == nil { 137 down = *trie.parentBit 138 } else { 139 bit := node.choose(ip) 140 down = node.child[bit] 141 if down == nil { 142 newNode.parent = parentIndirection{&node.child[bit], bit} 143 node.child[bit] = newNode 144 return 145 } 146 } 147 common := commonBits(down.bits, ip) 148 if common < cidr { 149 cidr = common 150 } 151 parent := node 152 153 if newNode.cidr == cidr { 154 bit := newNode.choose(down.bits) 155 down.parent = parentIndirection{&newNode.child[bit], bit} 156 newNode.child[bit] = down 157 if parent == nil { 158 newNode.parent = trie 159 *trie.parentBit = newNode 160 } else { 161 bit := parent.choose(newNode.bits) 162 newNode.parent = parentIndirection{&parent.child[bit], bit} 163 parent.child[bit] = newNode 164 } 165 return 166 } 167 168 node = &trieEntry{ 169 bits: append([]byte{}, newNode.bits...), 170 cidr: cidr, 171 bitAtByte: cidr / 8, 172 bitAtShift: 7 - (cidr % 8), 173 } 174 node.maskSelf() 175 176 bit := node.choose(down.bits) 177 down.parent = parentIndirection{&node.child[bit], bit} 178 node.child[bit] = down 179 bit = node.choose(newNode.bits) 180 newNode.parent = parentIndirection{&node.child[bit], bit} 181 node.child[bit] = newNode 182 if parent == nil { 183 node.parent = trie 184 *trie.parentBit = node 185 } else { 186 bit := parent.choose(node.bits) 187 node.parent = parentIndirection{&parent.child[bit], bit} 188 parent.child[bit] = node 189 } 190 } 191 192 func (node *trieEntry) lookup(ip []byte) *Peer { 193 var found *Peer 194 size := uint8(len(ip)) 195 for node != nil && commonBits(node.bits, ip) >= node.cidr { 196 if node.peer != nil { 197 found = node.peer 198 } 199 if node.bitAtByte == size { 200 break 201 } 202 bit := node.choose(ip) 203 node = node.child[bit] 204 } 205 return found 206 } 207 208 type AllowedIPs struct { 209 IPv4 *trieEntry 210 IPv6 *trieEntry 211 mutex sync.RWMutex 212 } 213 214 func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { 215 table.mutex.RLock() 216 defer table.mutex.RUnlock() 217 218 for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 219 node := elem.Value.(*trieEntry) 220 a, _ := netip.AddrFromSlice(node.bits) 221 if !cb(netip.PrefixFrom(a, int(node.cidr))) { 222 return 223 } 224 } 225 } 226 227 func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 228 table.mutex.Lock() 229 defer table.mutex.Unlock() 230 231 var next *list.Element 232 for elem := peer.trieEntries.Front(); elem != nil; elem = next { 233 next = elem.Next() 234 node := elem.Value.(*trieEntry) 235 236 node.removeFromPeerEntries() 237 node.peer = nil 238 if node.child[0] != nil && node.child[1] != nil { 239 continue 240 } 241 bit := 0 242 if node.child[0] == nil { 243 bit = 1 244 } 245 child := node.child[bit] 246 if child != nil { 247 child.parent = node.parent 248 } 249 *node.parent.parentBit = child 250 if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { 251 node.zeroizePointers() 252 continue 253 } 254 parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) 255 if parent.peer != nil { 256 node.zeroizePointers() 257 continue 258 } 259 child = parent.child[node.parent.parentBitType^1] 260 if child != nil { 261 child.parent = parent.parent 262 } 263 *parent.parent.parentBit = child 264 node.zeroizePointers() 265 parent.zeroizePointers() 266 } 267 } 268 269 func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { 270 table.mutex.Lock() 271 defer table.mutex.Unlock() 272 273 if prefix.Addr().Is6() { 274 ip := prefix.Addr().As16() 275 parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 276 } else if prefix.Addr().Is4() { 277 ip := prefix.Addr().As4() 278 parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 279 } else { 280 panic(errors.New("inserting unknown address type")) 281 } 282 } 283 284 func (table *AllowedIPs) Lookup(ip []byte) *Peer { 285 table.mutex.RLock() 286 defer table.mutex.RUnlock() 287 switch len(ip) { 288 case net.IPv6len: 289 return table.IPv6.lookup(ip) 290 case net.IPv4len: 291 return table.IPv4.lookup(ip) 292 default: 293 panic(errors.New("looking up unknown address type")) 294 } 295 }