github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/allowedips_rand_test.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"math/rand"
    10  	"net"
    11  	"sort"
    12  	"testing"
    13  
    14  	"golang.zx2c4.com/go118/netip"
    15  )
    16  
    17  const (
    18  	NumberOfPeers        = 100
    19  	NumberOfPeerRemovals = 4
    20  	NumberOfAddresses    = 250
    21  	NumberOfTests        = 10000
    22  )
    23  
    24  type SlowNode struct {
    25  	peer *Peer
    26  	cidr uint8
    27  	bits []byte
    28  }
    29  
    30  type SlowRouter []*SlowNode
    31  
    32  func (r SlowRouter) Len() int {
    33  	return len(r)
    34  }
    35  
    36  func (r SlowRouter) Less(i, j int) bool {
    37  	return r[i].cidr > r[j].cidr
    38  }
    39  
    40  func (r SlowRouter) Swap(i, j int) {
    41  	r[i], r[j] = r[j], r[i]
    42  }
    43  
    44  func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
    45  	for _, t := range r {
    46  		if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
    47  			t.peer = peer
    48  			t.bits = addr
    49  			return r
    50  		}
    51  	}
    52  	r = append(r, &SlowNode{
    53  		cidr: cidr,
    54  		bits: addr,
    55  		peer: peer,
    56  	})
    57  	sort.Sort(r)
    58  	return r
    59  }
    60  
    61  func (r SlowRouter) Lookup(addr []byte) *Peer {
    62  	for _, t := range r {
    63  		common := commonBits(t.bits, addr)
    64  		if common >= t.cidr {
    65  			return t.peer
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
    72  	n := 0
    73  	for _, x := range r {
    74  		if x.peer != peer {
    75  			r[n] = x
    76  			n++
    77  		}
    78  	}
    79  	return r[:n]
    80  }
    81  
    82  func TestTrieRandom(t *testing.T) {
    83  	var slow4, slow6 SlowRouter
    84  	var peers []*Peer
    85  	var allowedIPs AllowedIPs
    86  
    87  	rand.Seed(1)
    88  
    89  	for n := 0; n < NumberOfPeers; n++ {
    90  		peers = append(peers, &Peer{})
    91  	}
    92  
    93  	for n := 0; n < NumberOfAddresses; n++ {
    94  		var addr4 [4]byte
    95  		rand.Read(addr4[:])
    96  		cidr := uint8(rand.Intn(32) + 1)
    97  		index := rand.Intn(NumberOfPeers)
    98  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
    99  		slow4 = slow4.Insert(addr4[:], cidr, peers[index])
   100  
   101  		var addr6 [16]byte
   102  		rand.Read(addr6[:])
   103  		cidr = uint8(rand.Intn(128) + 1)
   104  		index = rand.Intn(NumberOfPeers)
   105  		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
   106  		slow6 = slow6.Insert(addr6[:], cidr, peers[index])
   107  	}
   108  
   109  	var p int
   110  	for p = 0; ; p++ {
   111  		for n := 0; n < NumberOfTests; n++ {
   112  			var addr4 [4]byte
   113  			rand.Read(addr4[:])
   114  			peer1 := slow4.Lookup(addr4[:])
   115  			peer2 := allowedIPs.Lookup(addr4[:])
   116  			if peer1 != peer2 {
   117  				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
   118  			}
   119  
   120  			var addr6 [16]byte
   121  			rand.Read(addr6[:])
   122  			peer1 = slow6.Lookup(addr6[:])
   123  			peer2 = allowedIPs.Lookup(addr6[:])
   124  			if peer1 != peer2 {
   125  				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
   126  			}
   127  		}
   128  		if p >= len(peers) || p >= NumberOfPeerRemovals {
   129  			break
   130  		}
   131  		allowedIPs.RemoveByPeer(peers[p])
   132  		slow4 = slow4.RemoveByPeer(peers[p])
   133  		slow6 = slow6.RemoveByPeer(peers[p])
   134  	}
   135  	for ; p < len(peers); p++ {
   136  		allowedIPs.RemoveByPeer(peers[p])
   137  	}
   138  
   139  	if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
   140  		t.Error("Failed to remove all nodes from trie by peer")
   141  	}
   142  }