github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/allowedips.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"container/list"
    10  	"encoding/binary"
    11  	"errors"
    12  	"math/bits"
    13  	"net"
    14  	"net/netip"
    15  	"sync"
    16  	"unsafe"
    17  )
    18  
// parentIndirection records where the pointer to a trie node is stored, so the
// node can later be unlinked or replaced in place without re-walking the trie.
type parentIndirection struct {
	// parentBit is the address of the slot that points at the node: either
	// &parent.child[parentBitType] of its parent node, or a table root pointer
	// (&AllowedIPs.IPv4 / &AllowedIPs.IPv6).
	parentBit     **trieEntry
	// parentBitType is the index (0 or 1) of that slot in the parent's child
	// array, or 2 when parentBit is a root pointer (see Insert). RemoveByPeer
	// treats any value > 1 as "there is no parent node".
	parentBitType uint8
}
    23  
// trieEntry is one node of a binary prefix trie keyed by address bits. Nodes
// with a non-nil peer represent an allowed prefix; nodes with a nil peer are
// branch points created by insert, or prefixes whose peer was removed by
// RemoveByPeer but which still have two children.
//
// NOTE: RemoveByPeer recovers a parent node from the address of one of its
// child slots via unsafe.Offsetof(child), so child must stay an inline array
// field of this struct.
type trieEntry struct {
	// peer owning this prefix, or nil for a branch-only node.
	peer        *Peer
	// child[b] is the subtree whose bit right after this node's prefix is b.
	child       [2]*trieEntry
	// parent locates the pointer slot that references this node.
	parent      parentIndirection
	// cidr is the prefix length in bits.
	cidr        uint8
	// bitAtByte/bitAtShift locate the bit right after the prefix: byte
	// cidr/8, shifted right by 7-cidr%8 (see choose).
	bitAtByte   uint8
	bitAtShift  uint8
	// bits holds the address bytes, masked to the first cidr bits.
	bits        []byte
	// perPeerElem is this node's element in peer.trieEntries, or nil when
	// the node is not on that list.
	perPeerElem *list.Element
}
    34  
    35  func commonBits(ip1, ip2 []byte) uint8 {
    36  	size := len(ip1)
    37  	if size == net.IPv4len {
    38  		a := binary.BigEndian.Uint32(ip1)
    39  		b := binary.BigEndian.Uint32(ip2)
    40  		x := a ^ b
    41  		return uint8(bits.LeadingZeros32(x))
    42  	} else if size == net.IPv6len {
    43  		a := binary.BigEndian.Uint64(ip1)
    44  		b := binary.BigEndian.Uint64(ip2)
    45  		x := a ^ b
    46  		if x != 0 {
    47  			return uint8(bits.LeadingZeros64(x))
    48  		}
    49  		a = binary.BigEndian.Uint64(ip1[8:])
    50  		b = binary.BigEndian.Uint64(ip2[8:])
    51  		x = a ^ b
    52  		return 64 + uint8(bits.LeadingZeros64(x))
    53  	} else {
    54  		panic("Wrong size bit string")
    55  	}
    56  }
    57  
    58  func (node *trieEntry) addToPeerEntries() {
    59  	node.perPeerElem = node.peer.trieEntries.PushBack(node)
    60  }
    61  
    62  func (node *trieEntry) removeFromPeerEntries() {
    63  	if node.perPeerElem != nil {
    64  		node.peer.trieEntries.Remove(node.perPeerElem)
    65  		node.perPeerElem = nil
    66  	}
    67  }
    68  
    69  func (node *trieEntry) choose(ip []byte) byte {
    70  	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
    71  }
    72  
    73  func (node *trieEntry) maskSelf() {
    74  	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
    75  	for i := 0; i < len(mask); i++ {
    76  		node.bits[i] &= mask[i]
    77  	}
    78  }
    79  
    80  func (node *trieEntry) zeroizePointers() {
    81  	// Make the garbage collector's life slightly easier
    82  	node.peer = nil
    83  	node.child[0] = nil
    84  	node.child[1] = nil
    85  	node.parent.parentBit = nil
    86  }
    87  
    88  func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
    89  	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
    90  		parent = node
    91  		if parent.cidr == cidr {
    92  			exact = true
    93  			return
    94  		}
    95  		bit := node.choose(ip)
    96  		node = node.child[bit]
    97  	}
    98  	return
    99  }
   100  
// insert adds the prefix ip/cidr, owned by peer, to the trie whose root pointer
// is stored at trie.parentBit. If exactly that prefix already exists, its
// peer is replaced instead.
//
// The new node keeps ip itself as its bits slice and masks it in place, so the
// caller's ip slice is retained and modified.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// node is the deepest existing node covering ip/cidr, or nil if even the
	// root does not cover it.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// Same prefix already present: move it to the new peer's list.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	// Masking newNode.bits also masks ip, since both share one backing array.
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree occupying the slot newNode belongs in:
	// the root when nothing covers the new prefix, otherwise the covering
	// node's child on ip's side.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// Free slot under the covering node: attach newNode as a leaf.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// cidr becomes the length of the prefix shared by down and the new
	// prefix, capped at the new prefix's own length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	// The new prefix covers down: splice newNode between parent (or the
	// root slot) and down.
	if newNode.cidr == cidr {
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// Otherwise down and newNode diverge at bit position cidr: create a
	// peerless branch node for their common prefix, holding both as children.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	// Hang the branch node where down used to be.
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
   190  
   191  func (node *trieEntry) lookup(ip []byte) *Peer {
   192  	var found *Peer
   193  	size := uint8(len(ip))
   194  	for node != nil && commonBits(node.bits, ip) >= node.cidr {
   195  		if node.peer != nil {
   196  			found = node.peer
   197  		}
   198  		if node.bitAtByte == size {
   199  			break
   200  		}
   201  		bit := node.choose(ip)
   202  		node = node.child[bit]
   203  	}
   204  	return found
   205  }
   206  
// AllowedIPs maps IP prefixes to peers, keeping one prefix trie per address
// family.
type AllowedIPs struct {
	// IPv4 is the root of the trie holding 4-byte prefixes.
	IPv4  *trieEntry
	// IPv6 is the root of the trie holding 16-byte prefixes.
	IPv6  *trieEntry
	// mutex guards both tries: Insert and RemoveByPeer take it exclusively,
	// Lookup and EntriesForPeer take it shared.
	mutex sync.RWMutex
}
   212  
   213  func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
   214  	table.mutex.RLock()
   215  	defer table.mutex.RUnlock()
   216  
   217  	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
   218  		node := elem.Value.(*trieEntry)
   219  		a, _ := netip.AddrFromSlice(node.bits)
   220  		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
   221  			return
   222  		}
   223  	}
   224  }
   225  
// RemoveByPeer removes every prefix owned by peer. Each of the peer's nodes
// loses its peer and, unless it still separates two subtrees, is unlinked
// from the trie; a peerless parent left with a single child is collapsed too.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Save next first: removeFromPeerEntries unlinks elem from the list.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// A node with two children is still needed as a branch point; keep
		// it in place, now peerless.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Pick node's only child (nil if node is a leaf) and splice it into
		// the slot that pointed at node.
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// If node had a child, or node hung off a root pointer
		// (parentBitType 2), no parent node can have become redundant.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// node was a leaf stored in parent.child[parentBitType]: recover the
		// parent by subtracting that slot's offset within trieEntry from the
		// slot's address. The uintptr arithmetic is kept inside a single
		// expression, as the unsafe.Pointer conversion rules require.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		// A parent that still carries a peer must stay, even with one child.
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// parent is peerless and just lost a child: replace it, in the
		// grandparent's slot, with its remaining child (node's sibling).
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
   267  
   268  func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
   269  	table.mutex.Lock()
   270  	defer table.mutex.Unlock()
   271  
   272  	if prefix.Addr().Is6() {
   273  		ip := prefix.Addr().As16()
   274  		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
   275  	} else if prefix.Addr().Is4() {
   276  		ip := prefix.Addr().As4()
   277  		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
   278  	} else {
   279  		panic(errors.New("inserting unknown address type"))
   280  	}
   281  }
   282  
   283  func (table *AllowedIPs) Lookup(ip []byte) *Peer {
   284  	table.mutex.RLock()
   285  	defer table.mutex.RUnlock()
   286  	switch len(ip) {
   287  	case net.IPv6len:
   288  		return table.IPv6.lookup(ip)
   289  	case net.IPv4len:
   290  		return table.IPv4.lookup(ip)
   291  	default:
   292  		panic(errors.New("looking up unknown address type"))
   293  	}
   294  }