github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/device/allowedips.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
import (
	"container/list"
	"encoding/binary"
	"errors"
	"math/bits"
	"net"
	"sync"
	"unsafe"
)
    16  
// parentIndirection records where the pointer to a trie node is stored, so
// the node can be replaced or unlinked in place without re-walking the trie
// from its root.
type parentIndirection struct {
	// parentBit points at the slot holding the pointer to this node: either
	// one of the parent's child slots (&parent.child[bit]) or a table root
	// pointer (&table.IPv4 / &table.IPv6).
	parentBit     **trieEntry
	// parentBitType is the index of that child slot (0 or 1), or 2 when
	// parentBit points at a root pointer. RemoveByPeer treats values > 1 as
	// "no parent node" and otherwise uses the index to recover the parent
	// node's address from parentBit.
	parentBitType uint8
}
    21  
// trieEntry is one node of the allowed-IPs binary prefix trie. A node stands
// for the prefix bits/cidr. Nodes that insert creates only to join two
// subtrees carry no peer.
//
// NOTE: RemoveByPeer recovers a parent node from a pointer into its child
// array via unsafe.Offsetof(child), so child must stay an inline array field.
type trieEntry struct {
	peer        *Peer             // owning peer; nil if this node carries no route
	child       [2]*trieEntry     // subtrees, indexed by the bit returned from choose
	parent      parentIndirection // where the pointer to this node is stored
	cidr        uint8             // prefix length in bits
	bitAtByte   uint8             // byte index of the first bit after the prefix (cidr / 8)
	bitAtShift  uint8             // right shift selecting that bit within its byte (7 - cidr%8)
	bits        net.IP            // prefix address, masked to cidr bits by maskSelf
	perPeerElem *list.Element     // this node's element in peer.trieEntries; nil when not listed
}
    32  
    33  func isLittleEndian() bool {
    34  	one := uint32(1)
    35  	return *(*byte)(unsafe.Pointer(&one)) != 0
    36  }
    37  
    38  func swapU32(i uint32) uint32 {
    39  	if !isLittleEndian() {
    40  		return i
    41  	}
    42  
    43  	return bits.ReverseBytes32(i)
    44  }
    45  
    46  func swapU64(i uint64) uint64 {
    47  	if !isLittleEndian() {
    48  		return i
    49  	}
    50  
    51  	return bits.ReverseBytes64(i)
    52  }
    53  
    54  func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
    55  	size := len(ip1)
    56  	if size == net.IPv4len {
    57  		a := (*uint32)(unsafe.Pointer(&ip1[0]))
    58  		b := (*uint32)(unsafe.Pointer(&ip2[0]))
    59  		x := *a ^ *b
    60  		return uint8(bits.LeadingZeros32(swapU32(x)))
    61  	} else if size == net.IPv6len {
    62  		a := (*uint64)(unsafe.Pointer(&ip1[0]))
    63  		b := (*uint64)(unsafe.Pointer(&ip2[0]))
    64  		x := *a ^ *b
    65  		if x != 0 {
    66  			return uint8(bits.LeadingZeros64(swapU64(x)))
    67  		}
    68  		a = (*uint64)(unsafe.Pointer(&ip1[8]))
    69  		b = (*uint64)(unsafe.Pointer(&ip2[8]))
    70  		x = *a ^ *b
    71  		return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
    72  	} else {
    73  		panic("Wrong size bit string")
    74  	}
    75  }
    76  
    77  func (node *trieEntry) addToPeerEntries() {
    78  	node.perPeerElem = node.peer.trieEntries.PushBack(node)
    79  }
    80  
    81  func (node *trieEntry) removeFromPeerEntries() {
    82  	if node.perPeerElem != nil {
    83  		node.peer.trieEntries.Remove(node.perPeerElem)
    84  		node.perPeerElem = nil
    85  	}
    86  }
    87  
    88  func (node *trieEntry) choose(ip net.IP) byte {
    89  	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
    90  }
    91  
    92  func (node *trieEntry) maskSelf() {
    93  	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
    94  	for i := 0; i < len(mask); i++ {
    95  		node.bits[i] &= mask[i]
    96  	}
    97  }
    98  
    99  func (node *trieEntry) zeroizePointers() {
   100  	// Make the garbage collector's life slightly easier
   101  	node.peer = nil
   102  	node.child[0] = nil
   103  	node.child[1] = nil
   104  	node.parent.parentBit = nil
   105  }
   106  
   107  func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
   108  	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
   109  		parent = node
   110  		if parent.cidr == cidr {
   111  			exact = true
   112  			return
   113  		}
   114  		bit := node.choose(ip)
   115  		node = node.child[bit]
   116  	}
   117  	return
   118  }
   119  
// insert adds the prefix ip/cidr, routed to peer, to the trie whose root
// pointer is *trie.parentBit. trie is expected to be a root indirection (see
// AllowedIPs.Insert, which passes parentBitType 2). If the exact prefix
// already exists, only its peer is replaced.
//
// NOTE(review): ip is stored directly as the node's bits and maskSelf masks
// it in place, so the caller's slice is aliased and modified. Later code in
// this function relies on that aliasing (see the commonBits call below).
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
	// Empty trie: the new prefix becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// node is the deepest existing node that contains ip/cidr (nil if even
	// the root does not).
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// The prefix already exists: move it to the new peer's entry list.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree that newNode must be placed above or
	// beside: the whole trie when no node contains the prefix, otherwise
	// node's child on ip's side.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// Free child slot: newNode simply becomes a leaf there.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// ip aliases newNode.bits, which maskSelf has already masked. From here
	// on, cidr is the length of the prefix shared by newNode and down,
	// capped at newNode's own prefix length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	if newNode.cidr == cidr {
		// down lies entirely within newNode's prefix: insert newNode between
		// parent (or the root pointer) and down.
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// newNode and down diverge after `cidr` bits: create a peer-less branch
	// node for the shared prefix (with its own copy of the bits), holding
	// both as children.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	// Hang the branch node where down used to be.
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
   209  
   210  func (node *trieEntry) lookup(ip net.IP) *Peer {
   211  	var found *Peer
   212  	size := uint8(len(ip))
   213  	for node != nil && commonBits(node.bits, ip) >= node.cidr {
   214  		if node.peer != nil {
   215  			found = node.peer
   216  		}
   217  		if node.bitAtByte == size {
   218  			break
   219  		}
   220  		bit := node.choose(ip)
   221  		node = node.child[bit]
   222  	}
   223  	return found
   224  }
   225  
// AllowedIPs maps IP addresses to peers using one prefix trie per address
// family. mutex guards both tries and the peers' entry lists: Lookup and
// EntriesForPeer take the read lock, while Insert and RemoveByPeer take the
// write lock.
type AllowedIPs struct {
	IPv4  *trieEntry // root of the trie for 4-byte addresses; nil when empty
	IPv6  *trieEntry // root of the trie for 16-byte addresses; nil when empty
	mutex sync.RWMutex
}
   231  
   232  func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
   233  	table.mutex.RLock()
   234  	defer table.mutex.RUnlock()
   235  
   236  	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
   237  		node := elem.Value.(*trieEntry)
   238  		if !cb(node.bits, node.cidr) {
   239  			return
   240  		}
   241  	}
   242  }
   243  
// RemoveByPeer removes every prefix routed to peer from the table. It prunes
// any node left without a peer that no longer needs to exist as a branch
// point.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Fetch the successor first: removeFromPeerEntries below unlinks
		// elem from the list.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// With two children the node is still needed as a branch point.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Splice out the node: its single child (or nil if it is a leaf)
		// takes its place in the slot that pointed at it.
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Done if the node had a child (nothing collapsed above it) or was
		// the root (parentBitType 2: there is no parent node).
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// The node was a leaf in child slot parentBitType of its parent.
		// Recover the parent trieEntry by stepping back from that slot's
		// address by the child array's offset plus the slot index times the
		// pointer size.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		// A parent that carries its own route must stay.
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// The parent is a peer-less node with at most one remaining child
		// (the removed leaf's sibling): splice it out the same way.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
   285  
   286  func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
   287  	table.mutex.Lock()
   288  	defer table.mutex.Unlock()
   289  
   290  	switch len(ip) {
   291  	case net.IPv6len:
   292  		parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
   293  	case net.IPv4len:
   294  		parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
   295  	default:
   296  		panic(errors.New("inserting unknown address type"))
   297  	}
   298  }
   299  
   300  func (table *AllowedIPs) Lookup(address []byte) *Peer {
   301  	table.mutex.RLock()
   302  	defer table.mutex.RUnlock()
   303  	switch len(address) {
   304  	case net.IPv6len:
   305  		return table.IPv6.lookup(address)
   306  	case net.IPv4len:
   307  		return table.IPv4.lookup(address)
   308  	default:
   309  		panic(errors.New("looking up unknown address type"))
   310  	}
   311  }