github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/allowedips.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "container/list" 10 "errors" 11 "math/bits" 12 "net" 13 "sync" 14 "unsafe" 15 ) 16 17 type trieEntry struct { 18 child [2]*trieEntry 19 peer *Peer 20 bits net.IP 21 cidr uint 22 bit_at_byte uint 23 bit_at_shift uint 24 perPeerElem *list.Element 25 } 26 27 func isLittleEndian() bool { 28 one := uint32(1) 29 return *(*byte)(unsafe.Pointer(&one)) != 0 30 } 31 32 func swapU32(i uint32) uint32 { 33 if !isLittleEndian() { 34 return i 35 } 36 37 return bits.ReverseBytes32(i) 38 } 39 40 func swapU64(i uint64) uint64 { 41 if !isLittleEndian() { 42 return i 43 } 44 45 return bits.ReverseBytes64(i) 46 } 47 48 func commonBits(ip1 net.IP, ip2 net.IP) uint { 49 size := len(ip1) 50 if size == net.IPv4len { 51 a := (*uint32)(unsafe.Pointer(&ip1[0])) 52 b := (*uint32)(unsafe.Pointer(&ip2[0])) 53 x := *a ^ *b 54 return uint(bits.LeadingZeros32(swapU32(x))) 55 } else if size == net.IPv6len { 56 a := (*uint64)(unsafe.Pointer(&ip1[0])) 57 b := (*uint64)(unsafe.Pointer(&ip2[0])) 58 x := *a ^ *b 59 if x != 0 { 60 return uint(bits.LeadingZeros64(swapU64(x))) 61 } 62 a = (*uint64)(unsafe.Pointer(&ip1[8])) 63 b = (*uint64)(unsafe.Pointer(&ip2[8])) 64 x = *a ^ *b 65 return 64 + uint(bits.LeadingZeros64(swapU64(x))) 66 } else { 67 panic("Wrong size bit string") 68 } 69 } 70 71 func (node *trieEntry) addToPeerEntries() { 72 node.perPeerElem = node.peer.trieEntries.PushBack(node) 73 } 74 75 func (node *trieEntry) removeFromPeerEntries() { 76 if node.perPeerElem != nil { 77 node.peer.trieEntries.Remove(node.perPeerElem) 78 node.perPeerElem = nil 79 } 80 } 81 82 func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { 83 if node == nil { 84 return node 85 } 86 87 // walk recursively 88 89 node.child[0] = node.child[0].removeByPeer(p) 90 node.child[1] = node.child[1].removeByPeer(p) 91 92 if node.peer != p { 93 return node 94 } 95 96 // remove peer & merge 97 98 node.removeFromPeerEntries() 99 node.peer = nil 100 if node.child[0] == nil { 101 return node.child[1] 102 } 103 return node.child[0] 104 } 105 106 func (node *trieEntry) choose(ip net.IP) byte { 107 return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 108 } 109 110 func (node *trieEntry) maskSelf() { 111 mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 112 for i := 0; i < len(mask); i++ { 113 node.bits[i] &= mask[i] 114 } 115 } 116 117 func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { 118 119 // at leaf 120 121 if node == nil { 122 node := &trieEntry{ 123 bits: ip, 124 peer: peer, 125 cidr: cidr, 126 bit_at_byte: cidr / 8, 127 bit_at_shift: 7 - (cidr % 8), 128 } 129 node.maskSelf() 130 node.addToPeerEntries() 131 return node 132 } 133 134 // traverse deeper 135 136 common := commonBits(node.bits, ip) 137 if node.cidr <= cidr && common >= node.cidr { 138 if node.cidr == cidr { 139 node.removeFromPeerEntries() 140 node.peer = peer 141 node.addToPeerEntries() 142 return node 143 } 144 bit := node.choose(ip) 145 node.child[bit] = node.child[bit].insert(ip, cidr, peer) 146 return node 147 } 148 149 // split node 150 151 newNode := &trieEntry{ 152 bits: ip, 153 peer: peer, 154 cidr: cidr, 155 bit_at_byte: cidr / 8, 156 bit_at_shift: 7 - (cidr % 8), 157 } 158 newNode.maskSelf() 159 newNode.addToPeerEntries() 160 161 cidr = min(cidr, common) 162 163 // check for shorter prefix 164 165 if newNode.cidr == cidr { 166 bit := newNode.choose(node.bits) 167 newNode.child[bit] = node 168 return newNode 169 } 170 171 // create new parent for node & newNode 172 173 parent := &trieEntry{ 174 bits: append([]byte{}, ip...), 175 peer: nil, 176 cidr: cidr, 177 bit_at_byte: cidr / 8, 178 bit_at_shift: 7 - (cidr % 8), 179 } 180 parent.maskSelf() 181 182 bit := parent.choose(ip) 183 parent.child[bit] = newNode 184 parent.child[bit^1] = node 185 186 return parent 187 } 188 189 func (node *trieEntry) lookup(ip net.IP) *Peer { 190 var found *Peer 191 size := uint(len(ip)) 192 for node != nil && commonBits(node.bits, ip) >= node.cidr { 193 if node.peer != nil { 194 found = node.peer 195 } 196 if node.bit_at_byte == size { 197 break 198 } 199 bit := node.choose(ip) 200 node = node.child[bit] 201 } 202 return found 203 } 204 205 type AllowedIPs struct { 206 IPv4 *trieEntry 207 IPv6 *trieEntry 208 mutex sync.RWMutex 209 } 210 211 func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { 212 table.mutex.RLock() 213 defer table.mutex.RUnlock() 214 215 for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 216 node := elem.Value.(*trieEntry) 217 if !cb(node.bits, node.cidr) { 218 return 219 } 220 } 221 } 222 223 func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 224 table.mutex.Lock() 225 defer table.mutex.Unlock() 226 227 table.IPv4 = table.IPv4.removeByPeer(peer) 228 table.IPv6 = table.IPv6.removeByPeer(peer) 229 } 230 231 func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { 232 table.mutex.Lock() 233 defer table.mutex.Unlock() 234 235 switch len(ip) { 236 case net.IPv6len: 237 table.IPv6 = table.IPv6.insert(ip, cidr, peer) 238 case net.IPv4len: 239 table.IPv4 = table.IPv4.insert(ip, cidr, peer) 240 default: 241 panic(errors.New("inserting unknown address type")) 242 } 243 } 244 245 func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { 246 table.mutex.RLock() 247 defer table.mutex.RUnlock() 248 return table.IPv4.lookup(address) 249 } 250 251 func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { 252 table.mutex.RLock() 253 defer table.mutex.RUnlock() 254 return table.IPv6.lookup(address) 255 }