github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/allowedips.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"container/list"
    10  	"encoding/binary"
    11  	"errors"
    12  	"math/bits"
    13  	"net"
    14  	"sync"
    15  	"unsafe"
    16  
    17  	"golang.zx2c4.com/go118/netip"
    18  )
    19  
// parentIndirection identifies the slot that points at a trie node, so the node
// can be replaced in place without walking down from the root again.
//
// parentBit is the address of that slot: either a table root pointer (such as
// &table.IPv4) or &parent.child[i] of some parent node. parentBitType records
// which kind of slot it is: 0 or 1 is the index i into the parent's child
// array, and 2 marks a table root slot that has no parent node.
type parentIndirection struct {
	parentBit     **trieEntry
	parentBitType uint8
}
    24  
    25  type trieEntry struct {
    26  	peer        *Peer
    27  	child       [2]*trieEntry
    28  	parent      parentIndirection
    29  	cidr        uint8
    30  	bitAtByte   uint8
    31  	bitAtShift  uint8
    32  	bits        []byte
    33  	perPeerElem *list.Element
    34  }
    35  
    36  func commonBits(ip1, ip2 []byte) uint8 {
    37  	size := len(ip1)
    38  	if size == net.IPv4len {
    39  		a := binary.BigEndian.Uint32(ip1)
    40  		b := binary.BigEndian.Uint32(ip2)
    41  		x := a ^ b
    42  		return uint8(bits.LeadingZeros32(x))
    43  	} else if size == net.IPv6len {
    44  		a := binary.BigEndian.Uint64(ip1)
    45  		b := binary.BigEndian.Uint64(ip2)
    46  		x := a ^ b
    47  		if x != 0 {
    48  			return uint8(bits.LeadingZeros64(x))
    49  		}
    50  		a = binary.BigEndian.Uint64(ip1[8:])
    51  		b = binary.BigEndian.Uint64(ip2[8:])
    52  		x = a ^ b
    53  		return 64 + uint8(bits.LeadingZeros64(x))
    54  	} else {
    55  		panic("Wrong size bit string")
    56  	}
    57  }
    58  
    59  func (node *trieEntry) addToPeerEntries() {
    60  	node.perPeerElem = node.peer.trieEntries.PushBack(node)
    61  }
    62  
    63  func (node *trieEntry) removeFromPeerEntries() {
    64  	if node.perPeerElem != nil {
    65  		node.peer.trieEntries.Remove(node.perPeerElem)
    66  		node.perPeerElem = nil
    67  	}
    68  }
    69  
    70  func (node *trieEntry) choose(ip []byte) byte {
    71  	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
    72  }
    73  
    74  func (node *trieEntry) maskSelf() {
    75  	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
    76  	for i := 0; i < len(mask); i++ {
    77  		node.bits[i] &= mask[i]
    78  	}
    79  }
    80  
    81  func (node *trieEntry) zeroizePointers() {
    82  	// Make the garbage collector's life slightly easier
    83  	node.peer = nil
    84  	node.child[0] = nil
    85  	node.child[1] = nil
    86  	node.parent.parentBit = nil
    87  }
    88  
    89  func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
    90  	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
    91  		parent = node
    92  		if parent.cidr == cidr {
    93  			exact = true
    94  			return
    95  		}
    96  		bit := node.choose(ip)
    97  		node = node.child[bit]
    98  	}
    99  	return
   100  }
   101  
// insert adds the prefix ip/cidr, owned by peer, to the trie whose root slot is
// trie.parentBit.
//
// The ip slice is stored in the new node without copying. maskSelf then clears
// its bits past cidr in place, so the caller's backing array is both retained
// and modified.
//
// If the exact prefix is already present, only its owning peer is replaced.
// Otherwise the new node is placed in one of three ways:
//   - attached as a leaf;
//   - spliced in above an existing subtree that it covers;
//   - when the two diverge earlier, hung under a new peerless intermediate
//     node at their common prefix length.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// node is the last node on ip's path whose prefix covers ip and is no
	// longer than cidr (nil if even the root does not qualify).
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// The prefix already exists: move it from its old peer's entry list to
		// the new peer's.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree that newNode must be placed relative to:
	// the root when no node covers ip, otherwise node's child on ip's side.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// The child slot is free: attach newNode there as a leaf.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// From here on cidr is the number of leading bits ip shares with down,
	// capped at newNode's own prefix length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	if newNode.cidr == cidr {
		// down agrees with ip on all of newNode's prefix bits, so newNode is
		// spliced in between parent (or the root slot) and down.
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// ip and down diverge inside newNode's prefix. Create a peerless
	// intermediate node at the common length with down and newNode as its
	// children. Its bits are a copy, so it does not alias newNode.bits.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	// Link the intermediate node under parent, or into the root slot when
	// there is no parent.
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
   191  
   192  func (node *trieEntry) lookup(ip []byte) *Peer {
   193  	var found *Peer
   194  	size := uint8(len(ip))
   195  	for node != nil && commonBits(node.bits, ip) >= node.cidr {
   196  		if node.peer != nil {
   197  			found = node.peer
   198  		}
   199  		if node.bitAtByte == size {
   200  			break
   201  		}
   202  		bit := node.choose(ip)
   203  		node = node.child[bit]
   204  	}
   205  	return found
   206  }
   207  
   208  type AllowedIPs struct {
   209  	IPv4  *trieEntry
   210  	IPv6  *trieEntry
   211  	mutex sync.RWMutex
   212  }
   213  
   214  func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
   215  	table.mutex.RLock()
   216  	defer table.mutex.RUnlock()
   217  
   218  	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
   219  		node := elem.Value.(*trieEntry)
   220  		a, _ := netip.AddrFromSlice(node.bits)
   221  		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
   222  			return
   223  		}
   224  	}
   225  }
   226  
// RemoveByPeer removes every prefix owned by peer from the table while holding
// the write lock.
//
// Each of the peer's nodes loses its peer. A node that still has two children
// stays in the trie as a peerless intermediate node. Otherwise it is unlinked,
// and its only child (or nil) takes its place in the parent slot.
//
// If the unlinked node was a leaf whose parent is a peerless node (rather than
// a table root slot), that parent is unlinked too. The leaf's sibling takes
// the parent's place.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// removeFromPeerEntries unlinks elem below, so fetch its successor first.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		if node.child[0] != nil && node.child[1] != nil {
			// Still joins two subtrees: keep it as an intermediate node.
			continue
		}
		// Splice node out, promoting its single child (nil if it is a leaf).
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Done if node had a child, or if node sat in a table root slot
		// (parentBitType 2), where there is no parent node to collapse.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// Recover the parent node from the address of its child slot:
		// parentBit == &parent.child[parentBitType], so step back by the index
		// offset within child and by child's field offset within trieEntry.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		if parent.peer != nil {
			// The parent owns a prefix itself, so it must stay.
			node.zeroizePointers()
			continue
		}
		// The parent is a peerless intermediate node that just lost one child;
		// replace it with node's sibling.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
   268  
   269  func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
   270  	table.mutex.Lock()
   271  	defer table.mutex.Unlock()
   272  
   273  	if prefix.Addr().Is6() {
   274  		ip := prefix.Addr().As16()
   275  		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
   276  	} else if prefix.Addr().Is4() {
   277  		ip := prefix.Addr().As4()
   278  		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
   279  	} else {
   280  		panic(errors.New("inserting unknown address type"))
   281  	}
   282  }
   283  
   284  func (table *AllowedIPs) Lookup(ip []byte) *Peer {
   285  	table.mutex.RLock()
   286  	defer table.mutex.RUnlock()
   287  	switch len(ip) {
   288  	case net.IPv6len:
   289  		return table.IPv6.lookup(ip)
   290  	case net.IPv4len:
   291  		return table.IPv4.lookup(ip)
   292  	default:
   293  		panic(errors.New("looking up unknown address type"))
   294  	}
   295  }