github.com/vpnishe/netstack@v1.10.6/tcpip/stack/transport_demuxer_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack_test 16 17 import ( 18 "math" 19 "math/rand" 20 "testing" 21 22 "github.com/vpnishe/netstack/tcpip" 23 "github.com/vpnishe/netstack/tcpip/buffer" 24 "github.com/vpnishe/netstack/tcpip/header" 25 "github.com/vpnishe/netstack/tcpip/link/channel" 26 "github.com/vpnishe/netstack/tcpip/network/ipv4" 27 "github.com/vpnishe/netstack/tcpip/network/ipv6" 28 "github.com/vpnishe/netstack/tcpip/stack" 29 "github.com/vpnishe/netstack/tcpip/transport/udp" 30 "github.com/vpnishe/netstack/waiter" 31 ) 32 33 const ( 34 stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" 35 testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" 36 37 stackAddr = "\x0a\x00\x00\x01" 38 stackPort = 1234 39 testPort = 4096 40 ) 41 42 type testContext struct { 43 t *testing.T 44 linkEPs map[string]*channel.Endpoint 45 s *stack.Stack 46 47 ep tcpip.Endpoint 48 wq waiter.Queue 49 } 50 51 func (c *testContext) cleanup() { 52 if c.ep != nil { 53 c.ep.Close() 54 } 55 } 56 57 func (c *testContext) createV6Endpoint(v6only bool) { 58 var err *tcpip.Error 59 c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) 60 if err != nil { 61 c.t.Fatalf("NewEndpoint failed: %v", err) 62 } 63 64 var v tcpip.V6OnlyOption 65 if v6only { 66 v = 1 67 } 68 if err := c.ep.SetSockOpt(v); err != nil { 69 c.t.Fatalf("SetSockOpt failed: %v", err) 70 } 71 } 72 73 // newDualTestContextMultiNic creates the testing context and also linkEpNames 74 // named NICs. 75 func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { 76 s := stack.New(stack.Options{ 77 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, 78 TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) 79 linkEPs := make(map[string]*channel.Endpoint) 80 for i, linkEpName := range linkEpNames { 81 channelEP := channel.New(256, mtu, "") 82 nicID := tcpip.NICID(i + 1) 83 if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { 84 t.Fatalf("CreateNIC failed: %v", err) 85 } 86 linkEPs[linkEpName] = channelEP 87 88 if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { 89 t.Fatalf("AddAddress IPv4 failed: %v", err) 90 } 91 92 if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { 93 t.Fatalf("AddAddress IPv6 failed: %v", err) 94 } 95 } 96 97 s.SetRouteTable([]tcpip.Route{ 98 { 99 Destination: header.IPv4EmptySubnet, 100 NIC: 1, 101 }, 102 { 103 Destination: header.IPv6EmptySubnet, 104 NIC: 1, 105 }, 106 }) 107 108 return &testContext{ 109 t: t, 110 s: s, 111 linkEPs: linkEPs, 112 } 113 } 114 115 type headers struct { 116 srcPort uint16 117 dstPort uint16 118 } 119 120 func newPayload() []byte { 121 b := make([]byte, 30+rand.Intn(100)) 122 for i := range b { 123 b[i] = byte(rand.Intn(256)) 124 } 125 return b 126 } 127 128 func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { 129 // Allocate a buffer for data and headers. 130 buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) 131 copy(buf[len(buf)-len(payload):], payload) 132 133 // Initialize the IP header. 134 ip := header.IPv6(buf) 135 ip.Encode(&header.IPv6Fields{ 136 PayloadLength: uint16(header.UDPMinimumSize + len(payload)), 137 NextHeader: uint8(udp.ProtocolNumber), 138 HopLimit: 65, 139 SrcAddr: testV6Addr, 140 DstAddr: stackV6Addr, 141 }) 142 143 // Initialize the UDP header. 144 u := header.UDP(buf[header.IPv6MinimumSize:]) 145 u.Encode(&header.UDPFields{ 146 SrcPort: h.srcPort, 147 DstPort: h.dstPort, 148 Length: uint16(header.UDPMinimumSize + len(payload)), 149 }) 150 151 // Calculate the UDP pseudo-header checksum. 152 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) 153 154 // Calculate the UDP checksum and set it. 155 xsum = header.Checksum(payload, xsum) 156 u.SetChecksum(^u.CalculateChecksum(xsum)) 157 158 // Inject packet. 159 c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ 160 Data: buf.ToVectorisedView(), 161 }) 162 } 163 164 func TestTransportDemuxerRegister(t *testing.T) { 165 for _, test := range []struct { 166 name string 167 proto tcpip.NetworkProtocolNumber 168 want *tcpip.Error 169 }{ 170 {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, 171 {"success", ipv4.ProtocolNumber, nil}, 172 } { 173 t.Run(test.name, func(t *testing.T) { 174 s := stack.New(stack.Options{ 175 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, 176 TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) 177 if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { 178 t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) 179 } 180 }) 181 } 182 } 183 184 // TestReuseBindToDevice injects varied packets on input devices and checks that 185 // the distribution of packets received matches expectations. 186 func TestDistribution(t *testing.T) { 187 type endpointSockopts struct { 188 reuse int 189 bindToDevice string 190 } 191 for _, test := range []struct { 192 name string 193 // endpoints will received the inject packets. 194 endpoints []endpointSockopts 195 // wantedDistribution is the wanted ratio of packets received on each 196 // endpoint for each NIC on which packets are injected. 197 wantedDistributions map[string][]float64 198 }{ 199 { 200 "BindPortReuse", 201 // 5 endpoints that all have reuse set. 202 []endpointSockopts{ 203 endpointSockopts{1, ""}, 204 endpointSockopts{1, ""}, 205 endpointSockopts{1, ""}, 206 endpointSockopts{1, ""}, 207 endpointSockopts{1, ""}, 208 }, 209 map[string][]float64{ 210 // Injected packets on dev0 get distributed evenly. 211 "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, 212 }, 213 }, 214 { 215 "BindToDevice", 216 // 3 endpoints with various bindings. 217 []endpointSockopts{ 218 endpointSockopts{0, "dev0"}, 219 endpointSockopts{0, "dev1"}, 220 endpointSockopts{0, "dev2"}, 221 }, 222 map[string][]float64{ 223 // Injected packets on dev0 go only to the endpoint bound to dev0. 224 "dev0": []float64{1, 0, 0}, 225 // Injected packets on dev1 go only to the endpoint bound to dev1. 226 "dev1": []float64{0, 1, 0}, 227 // Injected packets on dev2 go only to the endpoint bound to dev2. 228 "dev2": []float64{0, 0, 1}, 229 }, 230 }, 231 { 232 "ReuseAndBindToDevice", 233 // 6 endpoints with various bindings. 234 []endpointSockopts{ 235 endpointSockopts{1, "dev0"}, 236 endpointSockopts{1, "dev0"}, 237 endpointSockopts{1, "dev1"}, 238 endpointSockopts{1, "dev1"}, 239 endpointSockopts{1, "dev1"}, 240 endpointSockopts{1, ""}, 241 }, 242 map[string][]float64{ 243 // Injected packets on dev0 get distributed among endpoints bound to 244 // dev0. 245 "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, 246 // Injected packets on dev1 get distributed among endpoints bound to 247 // dev1 or unbound. 248 "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, 249 // Injected packets on dev999 go only to the unbound. 250 "dev999": []float64{0, 0, 0, 0, 0, 1}, 251 }, 252 }, 253 } { 254 t.Run(test.name, func(t *testing.T) { 255 for device, wantedDistribution := range test.wantedDistributions { 256 t.Run(device, func(t *testing.T) { 257 var devices []string 258 for d := range test.wantedDistributions { 259 devices = append(devices, d) 260 } 261 c := newDualTestContextMultiNic(t, defaultMTU, devices) 262 defer c.cleanup() 263 264 c.createV6Endpoint(false) 265 266 eps := make(map[tcpip.Endpoint]int) 267 268 pollChannel := make(chan tcpip.Endpoint) 269 for i, endpoint := range test.endpoints { 270 // Try to receive the data. 271 wq := waiter.Queue{} 272 we, ch := waiter.NewChannelEntry(nil) 273 wq.EventRegister(&we, waiter.EventIn) 274 defer wq.EventUnregister(&we) 275 defer close(ch) 276 277 var err *tcpip.Error 278 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) 279 if err != nil { 280 c.t.Fatalf("NewEndpoint failed: %v", err) 281 } 282 eps[ep] = i 283 284 go func(ep tcpip.Endpoint) { 285 for range ch { 286 pollChannel <- ep 287 } 288 }(ep) 289 290 defer ep.Close() 291 reusePortOption := tcpip.ReusePortOption(endpoint.reuse) 292 if err := ep.SetSockOpt(reusePortOption); err != nil { 293 c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) 294 } 295 bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) 296 if err := ep.SetSockOpt(bindToDeviceOption); err != nil { 297 c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) 298 } 299 if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { 300 t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) 301 } 302 } 303 304 npackets := 100000 305 nports := 10000 306 if got, want := len(test.endpoints), len(wantedDistribution); got != want { 307 t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) 308 } 309 ports := make(map[uint16]tcpip.Endpoint) 310 stats := make(map[tcpip.Endpoint]int) 311 for i := 0; i < npackets; i++ { 312 // Send a packet. 313 port := uint16(i % nports) 314 payload := newPayload() 315 c.sendV6Packet(payload, 316 &headers{ 317 srcPort: testPort + port, 318 dstPort: stackPort}, 319 device) 320 321 var addr tcpip.FullAddress 322 ep := <-pollChannel 323 _, _, err := ep.Read(&addr) 324 if err != nil { 325 c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) 326 } 327 stats[ep]++ 328 if i < nports { 329 ports[uint16(i)] = ep 330 } else { 331 // Check that all packets from one client are handled by the same 332 // socket. 333 if want, got := ports[port], ep; want != got { 334 t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) 335 } 336 } 337 } 338 339 // Check that a packet distribution is as expected. 340 for ep, i := range eps { 341 wantedRatio := wantedDistribution[i] 342 wantedRecv := wantedRatio * float64(npackets) 343 actualRecv := stats[ep] 344 actualRatio := float64(stats[ep]) / float64(npackets) 345 // The deviation is less than 10%. 346 if math.Abs(actualRatio-wantedRatio) > 0.05 { 347 t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) 348 } 349 } 350 }) 351 } 352 }) 353 } 354 }