github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/allowedips_rand_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "math/rand" 10 "net" 11 "net/netip" 12 "sort" 13 "testing" 14 ) 15 16 const ( 17 NumberOfPeers = 100 18 NumberOfPeerRemovals = 4 19 NumberOfAddresses = 250 20 NumberOfTests = 10000 21 ) 22 23 type SlowNode struct { 24 peer *Peer 25 cidr uint8 26 bits []byte 27 } 28 29 type SlowRouter []*SlowNode 30 31 func (r SlowRouter) Len() int { 32 return len(r) 33 } 34 35 func (r SlowRouter) Less(i, j int) bool { 36 return r[i].cidr > r[j].cidr 37 } 38 39 func (r SlowRouter) Swap(i, j int) { 40 r[i], r[j] = r[j], r[i] 41 } 42 43 func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 44 for _, t := range r { 45 if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 46 t.peer = peer 47 t.bits = addr 48 return r 49 } 50 } 51 r = append(r, &SlowNode{ 52 cidr: cidr, 53 bits: addr, 54 peer: peer, 55 }) 56 sort.Sort(r) 57 return r 58 } 59 60 func (r SlowRouter) Lookup(addr []byte) *Peer { 61 for _, t := range r { 62 common := commonBits(t.bits, addr) 63 if common >= t.cidr { 64 return t.peer 65 } 66 } 67 return nil 68 } 69 70 func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 71 n := 0 72 for _, x := range r { 73 if x.peer != peer { 74 r[n] = x 75 n++ 76 } 77 } 78 return r[:n] 79 } 80 81 func TestTrieRandom(t *testing.T) { 82 var slow4, slow6 SlowRouter 83 var peers []*Peer 84 var allowedIPs AllowedIPs 85 86 rand.Seed(1) 87 88 for n := 0; n < NumberOfPeers; n++ { 89 peers = append(peers, &Peer{}) 90 } 91 92 for n := 0; n < NumberOfAddresses; n++ { 93 var addr4 [4]byte 94 rand.Read(addr4[:]) 95 cidr := uint8(rand.Intn(32) + 1) 96 index := rand.Intn(NumberOfPeers) 97 allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) 98 slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 99 100 var addr6 [16]byte 101 rand.Read(addr6[:]) 102 cidr = uint8(rand.Intn(128) + 1) 103 index = rand.Intn(NumberOfPeers) 104 allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) 105 slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 106 } 107 108 var p int 109 for p = 0; ; p++ { 110 for n := 0; n < NumberOfTests; n++ { 111 var addr4 [4]byte 112 rand.Read(addr4[:]) 113 peer1 := slow4.Lookup(addr4[:]) 114 peer2 := allowedIPs.Lookup(addr4[:]) 115 if peer1 != peer2 { 116 t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 117 } 118 119 var addr6 [16]byte 120 rand.Read(addr6[:]) 121 peer1 = slow6.Lookup(addr6[:]) 122 peer2 = allowedIPs.Lookup(addr6[:]) 123 if peer1 != peer2 { 124 t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 125 } 126 } 127 if p >= len(peers) || p >= NumberOfPeerRemovals { 128 break 129 } 130 allowedIPs.RemoveByPeer(peers[p]) 131 slow4 = slow4.RemoveByPeer(peers[p]) 132 slow6 = slow6.RemoveByPeer(peers[p]) 133 } 134 for ; p < len(peers); p++ { 135 allowedIPs.RemoveByPeer(peers[p]) 136 } 137 138 if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 139 t.Error("Failed to remove all nodes from trie by peer") 140 } 141 }