github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/allowedips_rand_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "math/rand" 10 "net" 11 "sort" 12 "testing" 13 14 "golang.zx2c4.com/go118/netip" 15 ) 16 17 const ( 18 NumberOfPeers = 100 19 NumberOfPeerRemovals = 4 20 NumberOfAddresses = 250 21 NumberOfTests = 10000 22 ) 23 24 type SlowNode struct { 25 peer *Peer 26 cidr uint8 27 bits []byte 28 } 29 30 type SlowRouter []*SlowNode 31 32 func (r SlowRouter) Len() int { 33 return len(r) 34 } 35 36 func (r SlowRouter) Less(i, j int) bool { 37 return r[i].cidr > r[j].cidr 38 } 39 40 func (r SlowRouter) Swap(i, j int) { 41 r[i], r[j] = r[j], r[i] 42 } 43 44 func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 45 for _, t := range r { 46 if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 47 t.peer = peer 48 t.bits = addr 49 return r 50 } 51 } 52 r = append(r, &SlowNode{ 53 cidr: cidr, 54 bits: addr, 55 peer: peer, 56 }) 57 sort.Sort(r) 58 return r 59 } 60 61 func (r SlowRouter) Lookup(addr []byte) *Peer { 62 for _, t := range r { 63 common := commonBits(t.bits, addr) 64 if common >= t.cidr { 65 return t.peer 66 } 67 } 68 return nil 69 } 70 71 func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 72 n := 0 73 for _, x := range r { 74 if x.peer != peer { 75 r[n] = x 76 n++ 77 } 78 } 79 return r[:n] 80 } 81 82 func TestTrieRandom(t *testing.T) { 83 var slow4, slow6 SlowRouter 84 var peers []*Peer 85 var allowedIPs AllowedIPs 86 87 rand.Seed(1) 88 89 for n := 0; n < NumberOfPeers; n++ { 90 peers = append(peers, &Peer{}) 91 } 92 93 for n := 0; n < NumberOfAddresses; n++ { 94 var addr4 [4]byte 95 rand.Read(addr4[:]) 96 cidr := uint8(rand.Intn(32) + 1) 97 index := rand.Intn(NumberOfPeers) 98 allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) 99 slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 100 101 var addr6 [16]byte 102 rand.Read(addr6[:]) 103 cidr = uint8(rand.Intn(128) + 1) 104 index = rand.Intn(NumberOfPeers) 105 allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) 106 slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 107 } 108 109 var p int 110 for p = 0; ; p++ { 111 for n := 0; n < NumberOfTests; n++ { 112 var addr4 [4]byte 113 rand.Read(addr4[:]) 114 peer1 := slow4.Lookup(addr4[:]) 115 peer2 := allowedIPs.Lookup(addr4[:]) 116 if peer1 != peer2 { 117 t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 118 } 119 120 var addr6 [16]byte 121 rand.Read(addr6[:]) 122 peer1 = slow6.Lookup(addr6[:]) 123 peer2 = allowedIPs.Lookup(addr6[:]) 124 if peer1 != peer2 { 125 t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 126 } 127 } 128 if p >= len(peers) || p >= NumberOfPeerRemovals { 129 break 130 } 131 allowedIPs.RemoveByPeer(peers[p]) 132 slow4 = slow4.RemoveByPeer(peers[p]) 133 slow6 = slow6.RemoveByPeer(peers[p]) 134 } 135 for ; p < len(peers); p++ { 136 allowedIPs.RemoveByPeer(peers[p]) 137 } 138 139 if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 140 t.Error("Failed to remove all nodes from trie by peer") 141 } 142 }