github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/allowedips.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
import (
	"container/list"
	"encoding/binary"
	"errors"
	"math/bits"
	"net"
	"sync"
	"unsafe"
)
    16  
// trieEntry is one node of the binary prefix trie used by AllowedIPs. A node
// represents the prefix bits/cidr. Nodes whose peer is nil exist only as
// branch points joining two diverging subtrees (see insert).
type trieEntry struct {
	child        [2]*trieEntry // subtrees, indexed by the bit that choose extracts
	peer         *Peer         // peer owning this prefix, or nil for a branch-only node
	bits         net.IP        // prefix address with bits beyond cidr cleared (see maskSelf)
	cidr         uint          // prefix length in bits
	bit_at_byte  uint          // cidr / 8: byte index of the bit used to pick a child
	bit_at_shift uint          // 7 - cidr%8: right-shift of that bit within its byte
	perPeerElem  *list.Element // this node's element in peer.trieEntries, or nil if unlinked
}
    26  
    27  func isLittleEndian() bool {
    28  	one := uint32(1)
    29  	return *(*byte)(unsafe.Pointer(&one)) != 0
    30  }
    31  
    32  func swapU32(i uint32) uint32 {
    33  	if !isLittleEndian() {
    34  		return i
    35  	}
    36  
    37  	return bits.ReverseBytes32(i)
    38  }
    39  
    40  func swapU64(i uint64) uint64 {
    41  	if !isLittleEndian() {
    42  		return i
    43  	}
    44  
    45  	return bits.ReverseBytes64(i)
    46  }
    47  
    48  func commonBits(ip1 net.IP, ip2 net.IP) uint {
    49  	size := len(ip1)
    50  	if size == net.IPv4len {
    51  		a := (*uint32)(unsafe.Pointer(&ip1[0]))
    52  		b := (*uint32)(unsafe.Pointer(&ip2[0]))
    53  		x := *a ^ *b
    54  		return uint(bits.LeadingZeros32(swapU32(x)))
    55  	} else if size == net.IPv6len {
    56  		a := (*uint64)(unsafe.Pointer(&ip1[0]))
    57  		b := (*uint64)(unsafe.Pointer(&ip2[0]))
    58  		x := *a ^ *b
    59  		if x != 0 {
    60  			return uint(bits.LeadingZeros64(swapU64(x)))
    61  		}
    62  		a = (*uint64)(unsafe.Pointer(&ip1[8]))
    63  		b = (*uint64)(unsafe.Pointer(&ip2[8]))
    64  		x = *a ^ *b
    65  		return 64 + uint(bits.LeadingZeros64(swapU64(x)))
    66  	} else {
    67  		panic("Wrong size bit string")
    68  	}
    69  }
    70  
    71  func (node *trieEntry) addToPeerEntries() {
    72  	node.perPeerElem = node.peer.trieEntries.PushBack(node)
    73  }
    74  
    75  func (node *trieEntry) removeFromPeerEntries() {
    76  	if node.perPeerElem != nil {
    77  		node.peer.trieEntries.Remove(node.perPeerElem)
    78  		node.perPeerElem = nil
    79  	}
    80  }
    81  
    82  func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
    83  	if node == nil {
    84  		return node
    85  	}
    86  
    87  	// walk recursively
    88  
    89  	node.child[0] = node.child[0].removeByPeer(p)
    90  	node.child[1] = node.child[1].removeByPeer(p)
    91  
    92  	if node.peer != p {
    93  		return node
    94  	}
    95  
    96  	// remove peer & merge
    97  
    98  	node.removeFromPeerEntries()
    99  	node.peer = nil
   100  	if node.child[0] == nil {
   101  		return node.child[1]
   102  	}
   103  	return node.child[0]
   104  }
   105  
   106  func (node *trieEntry) choose(ip net.IP) byte {
   107  	return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
   108  }
   109  
   110  func (node *trieEntry) maskSelf() {
   111  	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
   112  	for i := 0; i < len(mask); i++ {
   113  		node.bits[i] &= mask[i]
   114  	}
   115  }
   116  
// insert assigns the prefix ip/cidr to peer within the subtree rooted at node.
// It returns the subtree's new root: node itself, or a freshly created node
// that now sits above it. A nil node is treated as an empty subtree.
//
// NOTE(review): when a new node is created for the prefix, ip itself (not a
// copy) becomes that node's bits and is masked in place by maskSelf. The
// caller's slice is therefore retained and modified; callers should pass a
// slice they own.
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {

	// at leaf: the subtree is empty, so the prefix becomes a new leaf node.
	// The bit used to branch below a node is bit number cidr, counted from the
	// most significant bit of byte 0: byte cidr/8, shift 7-cidr%8.

	if node == nil {
		node := &trieEntry{ // shadows the nil receiver
			bits:         ip,
			peer:         peer,
			cidr:         cidr,
			bit_at_byte:  cidr / 8,
			bit_at_shift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		return node
	}

	// traverse deeper: ip lies inside node's prefix and the new prefix is at
	// least as long as node's.

	common := commonBits(node.bits, ip)
	if node.cidr <= cidr && common >= node.cidr {
		if node.cidr == cidr {
			// Exactly this prefix already exists (possibly as a peer-less
			// branch node): move it from the old peer's list to peer's list.
			node.removeFromPeerEntries()
			node.peer = peer
			node.addToPeerEntries()
			return node
		}
		bit := node.choose(ip)
		node.child[bit] = node.child[bit].insert(ip, cidr, peer)
		return node
	}

	// split node: either ip falls outside node's prefix, or the new prefix is
	// shorter than node's. Either way a new node is needed at or above node.

	newNode := &trieEntry{
		bits:         ip,
		peer:         peer,
		cidr:         cidr,
		bit_at_byte:  cidr / 8,
		bit_at_shift: 7 - (cidr % 8),
	}
	newNode.maskSelf() // also masks ip itself, since newNode.bits aliases it
	newNode.addToPeerEntries()

	// From here on, cidr is the length of the prefix shared by the new
	// prefix and node's prefix.
	cidr = min(cidr, common)

	// check for shorter prefix: the new prefix contains node's prefix, so
	// newNode becomes node's parent.

	if newNode.cidr == cidr {
		bit := newNode.choose(node.bits)
		newNode.child[bit] = node
		return newNode
	}

	// create new parent for node & newNode: the two prefixes diverge at bit
	// `cidr`, so a peer-less branch node of that length joins them. It gets
	// its own copy of ip because ip is already owned by newNode.

	parent := &trieEntry{
		bits:         append([]byte{}, ip...),
		peer:         nil,
		cidr:         cidr,
		bit_at_byte:  cidr / 8,
		bit_at_shift: 7 - (cidr % 8),
	}
	parent.maskSelf()

	// newNode and node differ at the branching bit, so they take opposite
	// children.
	bit := parent.choose(ip)
	parent.child[bit] = newNode
	parent.child[bit^1] = node

	return parent
}
   188  
   189  func (node *trieEntry) lookup(ip net.IP) *Peer {
   190  	var found *Peer
   191  	size := uint(len(ip))
   192  	for node != nil && commonBits(node.bits, ip) >= node.cidr {
   193  		if node.peer != nil {
   194  			found = node.peer
   195  		}
   196  		if node.bit_at_byte == size {
   197  			break
   198  		}
   199  		bit := node.choose(ip)
   200  		node = node.child[bit]
   201  	}
   202  	return found
   203  }
   204  
// AllowedIPs maps IP prefixes to peers, with one prefix trie per address
// family. Lookups return the peer of the longest matching prefix. The methods
// defined alongside it acquire mutex. Touching IPv4/IPv6 directly bypasses
// that lock.
type AllowedIPs struct {
	IPv4  *trieEntry   // root of the trie holding 4-byte prefixes
	IPv6  *trieEntry   // root of the trie holding 16-byte prefixes
	mutex sync.RWMutex // guards both tries: readers RLock, writers Lock
}
   210  
   211  func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
   212  	table.mutex.RLock()
   213  	defer table.mutex.RUnlock()
   214  
   215  	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
   216  		node := elem.Value.(*trieEntry)
   217  		if !cb(node.bits, node.cidr) {
   218  			return
   219  		}
   220  	}
   221  }
   222  
   223  func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
   224  	table.mutex.Lock()
   225  	defer table.mutex.Unlock()
   226  
   227  	table.IPv4 = table.IPv4.removeByPeer(peer)
   228  	table.IPv6 = table.IPv6.removeByPeer(peer)
   229  }
   230  
   231  func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
   232  	table.mutex.Lock()
   233  	defer table.mutex.Unlock()
   234  
   235  	switch len(ip) {
   236  	case net.IPv6len:
   237  		table.IPv6 = table.IPv6.insert(ip, cidr, peer)
   238  	case net.IPv4len:
   239  		table.IPv4 = table.IPv4.insert(ip, cidr, peer)
   240  	default:
   241  		panic(errors.New("inserting unknown address type"))
   242  	}
   243  }
   244  
   245  func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
   246  	table.mutex.RLock()
   247  	defer table.mutex.RUnlock()
   248  	return table.IPv4.lookup(address)
   249  }
   250  
   251  func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
   252  	table.mutex.RLock()
   253  	defer table.mutex.RUnlock()
   254  	return table.IPv6.lookup(address)
   255  }