github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/allowedips_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"math/rand"
    10  	"net"
    11  	"net/netip"
    12  	"testing"
    13  )
    14  
    15  type testPairCommonBits struct {
    16  	s1    []byte
    17  	s2    []byte
    18  	match uint8
    19  }
    20  
    21  func TestCommonBits(t *testing.T) {
    22  	tests := []testPairCommonBits{
    23  		{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
    24  		{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
    25  		{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
    26  		{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
    27  		{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
    28  	}
    29  
    30  	for _, p := range tests {
    31  		v := commonBits(p.s1, p.s2)
    32  		if v != p.match {
    33  			t.Error(
    34  				"For slice", p.s1, p.s2,
    35  				"expected match", p.match,
    36  				",but got", v,
    37  			)
    38  		}
    39  	}
    40  }
    41  
    42  func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
    43  	var trie *trieEntry
    44  	var peers []*Peer
    45  	root := parentIndirection{&trie, 2}
    46  
    47  	rand.Seed(1)
    48  
    49  	const AddressLength = 4
    50  
    51  	for n := 0; n < peerNumber; n++ {
    52  		peers = append(peers, &Peer{})
    53  	}
    54  
    55  	for n := 0; n < addressNumber; n++ {
    56  		var addr [AddressLength]byte
    57  		rand.Read(addr[:])
    58  		cidr := uint8(rand.Uint32() % (AddressLength * 8))
    59  		index := rand.Int() % peerNumber
    60  		root.insert(addr[:], cidr, peers[index])
    61  	}
    62  
    63  	for n := 0; n < b.N; n++ {
    64  		var addr [AddressLength]byte
    65  		rand.Read(addr[:])
    66  		trie.lookup(addr[:])
    67  	}
    68  }
    69  
    70  func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
    71  	benchmarkTrie(100, 1000, net.IPv4len, b)
    72  }
    73  
    74  func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
    75  	benchmarkTrie(10, 10, net.IPv4len, b)
    76  }
    77  
    78  func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
    79  	benchmarkTrie(100, 1000, net.IPv6len, b)
    80  }
    81  
    82  func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
    83  	benchmarkTrie(10, 10, net.IPv6len, b)
    84  }
    85  
    86  /* Test ported from kernel implementation:
    87   * selftest/allowedips.h
    88   */
    89  func TestTrieIPv4(t *testing.T) {
    90  	a := &Peer{}
    91  	b := &Peer{}
    92  	c := &Peer{}
    93  	d := &Peer{}
    94  	e := &Peer{}
    95  	g := &Peer{}
    96  	h := &Peer{}
    97  
    98  	var allowedIPs AllowedIPs
    99  
   100  	insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
   101  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
   102  	}
   103  
   104  	assertEQ := func(peer *Peer, a, b, c, d byte) {
   105  		p := allowedIPs.Lookup([]byte{a, b, c, d})
   106  		if p != peer {
   107  			t.Error("Assert EQ failed")
   108  		}
   109  	}
   110  
   111  	assertNEQ := func(peer *Peer, a, b, c, d byte) {
   112  		p := allowedIPs.Lookup([]byte{a, b, c, d})
   113  		if p == peer {
   114  			t.Error("Assert NEQ failed")
   115  		}
   116  	}
   117  
   118  	insert(a, 192, 168, 4, 0, 24)
   119  	insert(b, 192, 168, 4, 4, 32)
   120  	insert(c, 192, 168, 0, 0, 16)
   121  	insert(d, 192, 95, 5, 64, 27)
   122  	insert(c, 192, 95, 5, 65, 27)
   123  	insert(e, 0, 0, 0, 0, 0)
   124  	insert(g, 64, 15, 112, 0, 20)
   125  	insert(h, 64, 15, 123, 211, 25)
   126  	insert(a, 10, 0, 0, 0, 25)
   127  	insert(b, 10, 0, 0, 128, 25)
   128  	insert(a, 10, 1, 0, 0, 30)
   129  	insert(b, 10, 1, 0, 4, 30)
   130  	insert(c, 10, 1, 0, 8, 29)
   131  	insert(d, 10, 1, 0, 16, 29)
   132  
   133  	assertEQ(a, 192, 168, 4, 20)
   134  	assertEQ(a, 192, 168, 4, 0)
   135  	assertEQ(b, 192, 168, 4, 4)
   136  	assertEQ(c, 192, 168, 200, 182)
   137  	assertEQ(c, 192, 95, 5, 68)
   138  	assertEQ(e, 192, 95, 5, 96)
   139  	assertEQ(g, 64, 15, 116, 26)
   140  	assertEQ(g, 64, 15, 127, 3)
   141  
   142  	insert(a, 1, 0, 0, 0, 32)
   143  	insert(a, 64, 0, 0, 0, 32)
   144  	insert(a, 128, 0, 0, 0, 32)
   145  	insert(a, 192, 0, 0, 0, 32)
   146  	insert(a, 255, 0, 0, 0, 32)
   147  
   148  	assertEQ(a, 1, 0, 0, 0)
   149  	assertEQ(a, 64, 0, 0, 0)
   150  	assertEQ(a, 128, 0, 0, 0)
   151  	assertEQ(a, 192, 0, 0, 0)
   152  	assertEQ(a, 255, 0, 0, 0)
   153  
   154  	allowedIPs.RemoveByPeer(a)
   155  
   156  	assertNEQ(a, 1, 0, 0, 0)
   157  	assertNEQ(a, 64, 0, 0, 0)
   158  	assertNEQ(a, 128, 0, 0, 0)
   159  	assertNEQ(a, 192, 0, 0, 0)
   160  	assertNEQ(a, 255, 0, 0, 0)
   161  
   162  	allowedIPs.RemoveByPeer(a)
   163  	allowedIPs.RemoveByPeer(b)
   164  	allowedIPs.RemoveByPeer(c)
   165  	allowedIPs.RemoveByPeer(d)
   166  	allowedIPs.RemoveByPeer(e)
   167  	allowedIPs.RemoveByPeer(g)
   168  	allowedIPs.RemoveByPeer(h)
   169  	if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
   170  		t.Error("Expected removing all the peers to empty trie, but it did not")
   171  	}
   172  
   173  	insert(a, 192, 168, 0, 0, 16)
   174  	insert(a, 192, 168, 0, 0, 24)
   175  
   176  	allowedIPs.RemoveByPeer(a)
   177  
   178  	assertNEQ(a, 192, 168, 0, 1)
   179  }
   180  
   181  /* Test ported from kernel implementation:
   182   * selftest/allowedips.h
   183   */
   184  func TestTrieIPv6(t *testing.T) {
   185  	a := &Peer{}
   186  	b := &Peer{}
   187  	c := &Peer{}
   188  	d := &Peer{}
   189  	e := &Peer{}
   190  	f := &Peer{}
   191  	g := &Peer{}
   192  	h := &Peer{}
   193  
   194  	var allowedIPs AllowedIPs
   195  
   196  	expand := func(a uint32) []byte {
   197  		var out [4]byte
   198  		out[0] = byte(a >> 24 & 0xff)
   199  		out[1] = byte(a >> 16 & 0xff)
   200  		out[2] = byte(a >> 8 & 0xff)
   201  		out[3] = byte(a & 0xff)
   202  		return out[:]
   203  	}
   204  
   205  	insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
   206  		var addr []byte
   207  		addr = append(addr, expand(a)...)
   208  		addr = append(addr, expand(b)...)
   209  		addr = append(addr, expand(c)...)
   210  		addr = append(addr, expand(d)...)
   211  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
   212  	}
   213  
   214  	assertEQ := func(peer *Peer, a, b, c, d uint32) {
   215  		var addr []byte
   216  		addr = append(addr, expand(a)...)
   217  		addr = append(addr, expand(b)...)
   218  		addr = append(addr, expand(c)...)
   219  		addr = append(addr, expand(d)...)
   220  		p := allowedIPs.Lookup(addr)
   221  		if p != peer {
   222  			t.Error("Assert EQ failed")
   223  		}
   224  	}
   225  
   226  	insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
   227  	insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
   228  	insert(e, 0, 0, 0, 0, 0)
   229  	insert(f, 0, 0, 0, 0, 0)
   230  	insert(g, 0x24046800, 0, 0, 0, 32)
   231  	insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
   232  	insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
   233  	insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
   234  	insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
   235  
   236  	assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
   237  	assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
   238  	assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
   239  	assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
   240  	assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
   241  	assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
   242  	assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
   243  	assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
   244  	assertEQ(h, 0x24046800, 0x40040800, 0, 0)
   245  	assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
   246  	assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
   247  }