github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/allowedips_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"math/rand"
    10  	"net"
    11  	"testing"
    12  
    13  	"golang.zx2c4.com/go118/netip"
    14  )
    15  
    16  type testPairCommonBits struct {
    17  	s1    []byte
    18  	s2    []byte
    19  	match uint8
    20  }
    21  
    22  func TestCommonBits(t *testing.T) {
    23  	tests := []testPairCommonBits{
    24  		{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
    25  		{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
    26  		{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
    27  		{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
    28  		{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
    29  	}
    30  
    31  	for _, p := range tests {
    32  		v := commonBits(p.s1, p.s2)
    33  		if v != p.match {
    34  			t.Error(
    35  				"For slice", p.s1, p.s2,
    36  				"expected match", p.match,
    37  				",but got", v,
    38  			)
    39  		}
    40  	}
    41  }
    42  
    43  func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
    44  	var trie *trieEntry
    45  	var peers []*Peer
    46  	root := parentIndirection{&trie, 2}
    47  
    48  	rand.Seed(1)
    49  
    50  	const AddressLength = 4
    51  
    52  	for n := 0; n < peerNumber; n++ {
    53  		peers = append(peers, &Peer{})
    54  	}
    55  
    56  	for n := 0; n < addressNumber; n++ {
    57  		var addr [AddressLength]byte
    58  		rand.Read(addr[:])
    59  		cidr := uint8(rand.Uint32() % (AddressLength * 8))
    60  		index := rand.Int() % peerNumber
    61  		root.insert(addr[:], cidr, peers[index])
    62  	}
    63  
    64  	for n := 0; n < b.N; n++ {
    65  		var addr [AddressLength]byte
    66  		rand.Read(addr[:])
    67  		trie.lookup(addr[:])
    68  	}
    69  }
    70  
    71  func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
    72  	benchmarkTrie(100, 1000, net.IPv4len, b)
    73  }
    74  
    75  func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
    76  	benchmarkTrie(10, 10, net.IPv4len, b)
    77  }
    78  
    79  func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
    80  	benchmarkTrie(100, 1000, net.IPv6len, b)
    81  }
    82  
    83  func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
    84  	benchmarkTrie(10, 10, net.IPv6len, b)
    85  }
    86  
    87  /* Test ported from kernel implementation:
    88   * selftest/allowedips.h
    89   */
    90  func TestTrieIPv4(t *testing.T) {
    91  	a := &Peer{}
    92  	b := &Peer{}
    93  	c := &Peer{}
    94  	d := &Peer{}
    95  	e := &Peer{}
    96  	g := &Peer{}
    97  	h := &Peer{}
    98  
    99  	var allowedIPs AllowedIPs
   100  
   101  	insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
   102  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
   103  	}
   104  
   105  	assertEQ := func(peer *Peer, a, b, c, d byte) {
   106  		p := allowedIPs.Lookup([]byte{a, b, c, d})
   107  		if p != peer {
   108  			t.Error("Assert EQ failed")
   109  		}
   110  	}
   111  
   112  	assertNEQ := func(peer *Peer, a, b, c, d byte) {
   113  		p := allowedIPs.Lookup([]byte{a, b, c, d})
   114  		if p == peer {
   115  			t.Error("Assert NEQ failed")
   116  		}
   117  	}
   118  
   119  	insert(a, 192, 168, 4, 0, 24)
   120  	insert(b, 192, 168, 4, 4, 32)
   121  	insert(c, 192, 168, 0, 0, 16)
   122  	insert(d, 192, 95, 5, 64, 27)
   123  	insert(c, 192, 95, 5, 65, 27)
   124  	insert(e, 0, 0, 0, 0, 0)
   125  	insert(g, 64, 15, 112, 0, 20)
   126  	insert(h, 64, 15, 123, 211, 25)
   127  	insert(a, 10, 0, 0, 0, 25)
   128  	insert(b, 10, 0, 0, 128, 25)
   129  	insert(a, 10, 1, 0, 0, 30)
   130  	insert(b, 10, 1, 0, 4, 30)
   131  	insert(c, 10, 1, 0, 8, 29)
   132  	insert(d, 10, 1, 0, 16, 29)
   133  
   134  	assertEQ(a, 192, 168, 4, 20)
   135  	assertEQ(a, 192, 168, 4, 0)
   136  	assertEQ(b, 192, 168, 4, 4)
   137  	assertEQ(c, 192, 168, 200, 182)
   138  	assertEQ(c, 192, 95, 5, 68)
   139  	assertEQ(e, 192, 95, 5, 96)
   140  	assertEQ(g, 64, 15, 116, 26)
   141  	assertEQ(g, 64, 15, 127, 3)
   142  
   143  	insert(a, 1, 0, 0, 0, 32)
   144  	insert(a, 64, 0, 0, 0, 32)
   145  	insert(a, 128, 0, 0, 0, 32)
   146  	insert(a, 192, 0, 0, 0, 32)
   147  	insert(a, 255, 0, 0, 0, 32)
   148  
   149  	assertEQ(a, 1, 0, 0, 0)
   150  	assertEQ(a, 64, 0, 0, 0)
   151  	assertEQ(a, 128, 0, 0, 0)
   152  	assertEQ(a, 192, 0, 0, 0)
   153  	assertEQ(a, 255, 0, 0, 0)
   154  
   155  	allowedIPs.RemoveByPeer(a)
   156  
   157  	assertNEQ(a, 1, 0, 0, 0)
   158  	assertNEQ(a, 64, 0, 0, 0)
   159  	assertNEQ(a, 128, 0, 0, 0)
   160  	assertNEQ(a, 192, 0, 0, 0)
   161  	assertNEQ(a, 255, 0, 0, 0)
   162  
   163  	allowedIPs.RemoveByPeer(a)
   164  	allowedIPs.RemoveByPeer(b)
   165  	allowedIPs.RemoveByPeer(c)
   166  	allowedIPs.RemoveByPeer(d)
   167  	allowedIPs.RemoveByPeer(e)
   168  	allowedIPs.RemoveByPeer(g)
   169  	allowedIPs.RemoveByPeer(h)
   170  	if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
   171  		t.Error("Expected removing all the peers to empty trie, but it did not")
   172  	}
   173  
   174  	insert(a, 192, 168, 0, 0, 16)
   175  	insert(a, 192, 168, 0, 0, 24)
   176  
   177  	allowedIPs.RemoveByPeer(a)
   178  
   179  	assertNEQ(a, 192, 168, 0, 1)
   180  }
   181  
   182  /* Test ported from kernel implementation:
   183   * selftest/allowedips.h
   184   */
   185  func TestTrieIPv6(t *testing.T) {
   186  	a := &Peer{}
   187  	b := &Peer{}
   188  	c := &Peer{}
   189  	d := &Peer{}
   190  	e := &Peer{}
   191  	f := &Peer{}
   192  	g := &Peer{}
   193  	h := &Peer{}
   194  
   195  	var allowedIPs AllowedIPs
   196  
   197  	expand := func(a uint32) []byte {
   198  		var out [4]byte
   199  		out[0] = byte(a >> 24 & 0xff)
   200  		out[1] = byte(a >> 16 & 0xff)
   201  		out[2] = byte(a >> 8 & 0xff)
   202  		out[3] = byte(a & 0xff)
   203  		return out[:]
   204  	}
   205  
   206  	insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
   207  		var addr []byte
   208  		addr = append(addr, expand(a)...)
   209  		addr = append(addr, expand(b)...)
   210  		addr = append(addr, expand(c)...)
   211  		addr = append(addr, expand(d)...)
   212  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
   213  	}
   214  
   215  	assertEQ := func(peer *Peer, a, b, c, d uint32) {
   216  		var addr []byte
   217  		addr = append(addr, expand(a)...)
   218  		addr = append(addr, expand(b)...)
   219  		addr = append(addr, expand(c)...)
   220  		addr = append(addr, expand(d)...)
   221  		p := allowedIPs.Lookup(addr)
   222  		if p != peer {
   223  			t.Error("Assert EQ failed")
   224  		}
   225  	}
   226  
   227  	insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
   228  	insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
   229  	insert(e, 0, 0, 0, 0, 0)
   230  	insert(f, 0, 0, 0, 0, 0)
   231  	insert(g, 0x24046800, 0, 0, 0, 32)
   232  	insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
   233  	insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
   234  	insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
   235  	insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
   236  
   237  	assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
   238  	assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
   239  	assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
   240  	assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
   241  	assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
   242  	assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
   243  	assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
   244  	assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
   245  	assertEQ(h, 0x24046800, 0x40040800, 0, 0)
   246  	assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
   247  	assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
   248  }