gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/stack/transport_demuxer_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack_test 16 17 import ( 18 "io/ioutil" 19 "math" 20 "math/rand" 21 "strconv" 22 "testing" 23 24 "gvisor.dev/gvisor/pkg/buffer" 25 "gvisor.dev/gvisor/pkg/tcpip" 26 "gvisor.dev/gvisor/pkg/tcpip/checksum" 27 "gvisor.dev/gvisor/pkg/tcpip/header" 28 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 29 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 30 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 31 "gvisor.dev/gvisor/pkg/tcpip/ports" 32 "gvisor.dev/gvisor/pkg/tcpip/stack" 33 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 34 "gvisor.dev/gvisor/pkg/waiter" 35 ) 36 37 const ( 38 testDstPort = 1234 39 testSrcPort = 4096 40 ) 41 42 var ( 43 testDstAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")) 44 testSrcAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")) 45 46 testSrcAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01")) 47 testDstAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02")) 48 ) 49 50 type testContext struct { 51 linkEps map[tcpip.NICID]*channel.Endpoint 52 s *stack.Stack 53 wq waiter.Queue 54 } 55 56 // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. 57 func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { 58 s := stack.New(stack.Options{ 59 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 60 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 61 }) 62 linkEps := make(map[tcpip.NICID]*channel.Endpoint) 63 for _, linkEpID := range linkEpIDs { 64 channelEp := channel.New(256, mtu, "") 65 if err := s.CreateNIC(linkEpID, channelEp); err != nil { 66 t.Fatalf("CreateNIC failed: %s", err) 67 } 68 linkEps[linkEpID] = channelEp 69 70 protocolAddrV4 := tcpip.ProtocolAddress{ 71 Protocol: ipv4.ProtocolNumber, 72 AddressWithPrefix: testDstAddrV4.WithPrefix(), 73 } 74 if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { 75 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) 76 } 77 78 protocolAddrV6 := tcpip.ProtocolAddress{ 79 Protocol: ipv6.ProtocolNumber, 80 AddressWithPrefix: testDstAddrV6.WithPrefix(), 81 } 82 if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { 83 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) 84 } 85 } 86 87 s.SetRouteTable([]tcpip.Route{ 88 {Destination: header.IPv4EmptySubnet, NIC: 1}, 89 {Destination: header.IPv6EmptySubnet, NIC: 1}, 90 }) 91 92 return &testContext{ 93 s: s, 94 linkEps: linkEps, 95 } 96 } 97 98 type headers struct { 99 srcPort uint16 100 dstPort uint16 101 } 102 103 func newPayload() []byte { 104 b := make([]byte, 30+rand.Intn(100)) 105 for i := range b { 106 b[i] = byte(rand.Intn(256)) 107 } 108 return b 109 } 110 111 func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { 112 buf := make([]byte, header.UDPMinimumSize+header.IPv4MinimumSize+len(payload)) 113 payloadStart := len(buf) - len(payload) 114 copy(buf[payloadStart:], payload) 115 116 // Initialize the IP header. 117 ip := header.IPv4(buf) 118 ip.Encode(&header.IPv4Fields{ 119 TOS: 0x80, 120 TotalLength: uint16(len(buf)), 121 TTL: 65, 122 Protocol: uint8(udp.ProtocolNumber), 123 SrcAddr: testSrcAddrV4, 124 DstAddr: testDstAddrV4, 125 }) 126 ip.SetChecksum(^ip.CalculateChecksum()) 127 128 // Initialize the UDP header. 129 u := header.UDP(buf[header.IPv4MinimumSize:]) 130 u.Encode(&header.UDPFields{ 131 SrcPort: h.srcPort, 132 DstPort: h.dstPort, 133 Length: uint16(header.UDPMinimumSize + len(payload)), 134 }) 135 136 // Calculate the UDP pseudo-header checksum. 137 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) 138 139 // Calculate the UDP checksum and set it. 140 xsum = checksum.Checksum(payload, xsum) 141 u.SetChecksum(^u.CalculateChecksum(xsum)) 142 143 // Inject packet. 144 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 145 Payload: buffer.MakeWithData(buf), 146 }) 147 c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) 148 } 149 150 func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { 151 // Allocate a buffer for data and headers. 152 buf := make([]byte, header.UDPMinimumSize+header.IPv6MinimumSize+len(payload)) 153 copy(buf[len(buf)-len(payload):], payload) 154 155 // Initialize the IP header. 156 ip := header.IPv6(buf) 157 ip.Encode(&header.IPv6Fields{ 158 PayloadLength: uint16(header.UDPMinimumSize + len(payload)), 159 TransportProtocol: udp.ProtocolNumber, 160 HopLimit: 65, 161 SrcAddr: testSrcAddrV6, 162 DstAddr: testDstAddrV6, 163 }) 164 165 // Initialize the UDP header. 166 u := header.UDP(buf[header.IPv6MinimumSize:]) 167 u.Encode(&header.UDPFields{ 168 SrcPort: h.srcPort, 169 DstPort: h.dstPort, 170 Length: uint16(header.UDPMinimumSize + len(payload)), 171 }) 172 173 // Calculate the UDP pseudo-header checksum. 174 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) 175 176 // Calculate the UDP checksum and set it. 177 xsum = checksum.Checksum(payload, xsum) 178 u.SetChecksum(^u.CalculateChecksum(xsum)) 179 180 // Inject packet. 181 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 182 Payload: buffer.MakeWithData(buf), 183 }) 184 c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) 185 } 186 187 func TestTransportDemuxerRegister(t *testing.T) { 188 for _, test := range []struct { 189 name string 190 proto tcpip.NetworkProtocolNumber 191 want tcpip.Error 192 }{ 193 {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, 194 {"success", ipv4.ProtocolNumber, nil}, 195 } { 196 t.Run(test.name, func(t *testing.T) { 197 s := stack.New(stack.Options{ 198 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 199 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 200 }) 201 var wq waiter.Queue 202 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 203 if err != nil { 204 t.Fatal(err) 205 } 206 tEP, ok := ep.(stack.TransportEndpoint) 207 if !ok { 208 t.Fatalf("%T does not implement stack.TransportEndpoint", ep) 209 } 210 if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { 211 t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) 212 } 213 }) 214 } 215 } 216 217 func TestTransportDemuxerRegisterMultiple(t *testing.T) { 218 type test struct { 219 flags ports.Flags 220 want tcpip.Error 221 } 222 for _, subtest := range []struct { 223 name string 224 tests []test 225 }{ 226 {"zeroFlags", []test{ 227 {ports.Flags{}, nil}, 228 {ports.Flags{}, &tcpip.ErrPortInUse{}}, 229 }}, 230 {"multibindFlags", []test{ 231 // Allow multiple registrations same TransportEndpointID with multibind flags. 232 {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, 233 {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, 234 // Disallow registration w/same ID for a non-multibindflag. 235 {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, 236 }}, 237 } { 238 t.Run(subtest.name, func(t *testing.T) { 239 s := stack.New(stack.Options{ 240 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 241 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 242 }) 243 var eps []tcpip.Endpoint 244 for idx, test := range subtest.tests { 245 var wq waiter.Queue 246 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 247 if err != nil { 248 t.Fatal(err) 249 } 250 eps = append(eps, ep) 251 tEP, ok := ep.(stack.TransportEndpoint) 252 if !ok { 253 t.Fatalf("%T does not implement stack.TransportEndpoint", ep) 254 } 255 id := stack.TransportEndpointID{LocalPort: 1} 256 if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { 257 t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) 258 } 259 } 260 for _, ep := range eps { 261 ep.Close() 262 } 263 }) 264 } 265 } 266 267 // TestBindToDeviceDistribution injects varied packets on input devices and checks that 268 // the distribution of packets received matches expectations. 269 func TestBindToDeviceDistribution(t *testing.T) { 270 type endpointSockopts struct { 271 reuse bool 272 bindToDevice tcpip.NICID 273 } 274 tcs := []struct { 275 name string 276 // endpoints will received the inject packets. 277 endpoints []endpointSockopts 278 // wantDistributions is the want ratio of packets received on each 279 // endpoint for each NIC on which packets are injected. 280 wantDistributions map[tcpip.NICID][]float64 281 }{ 282 { 283 name: "BindPortReuse", 284 // 5 endpoints that all have reuse set. 285 endpoints: []endpointSockopts{ 286 {reuse: true, bindToDevice: 0}, 287 {reuse: true, bindToDevice: 0}, 288 {reuse: true, bindToDevice: 0}, 289 {reuse: true, bindToDevice: 0}, 290 {reuse: true, bindToDevice: 0}, 291 }, 292 wantDistributions: map[tcpip.NICID][]float64{ 293 // Injected packets on dev0 get distributed evenly. 294 1: {0.2, 0.2, 0.2, 0.2, 0.2}, 295 }, 296 }, 297 { 298 name: "BindToDevice", 299 // 3 endpoints with various bindings. 300 endpoints: []endpointSockopts{ 301 {reuse: false, bindToDevice: 1}, 302 {reuse: false, bindToDevice: 2}, 303 {reuse: false, bindToDevice: 3}, 304 }, 305 wantDistributions: map[tcpip.NICID][]float64{ 306 // Injected packets on dev0 go only to the endpoint bound to dev0. 307 1: {1, 0, 0}, 308 // Injected packets on dev1 go only to the endpoint bound to dev1. 309 2: {0, 1, 0}, 310 // Injected packets on dev2 go only to the endpoint bound to dev2. 311 3: {0, 0, 1}, 312 }, 313 }, 314 { 315 name: "ReuseAndBindToDevice", 316 // 6 endpoints with various bindings. 317 endpoints: []endpointSockopts{ 318 {reuse: true, bindToDevice: 1}, 319 {reuse: true, bindToDevice: 1}, 320 {reuse: true, bindToDevice: 2}, 321 {reuse: true, bindToDevice: 2}, 322 {reuse: true, bindToDevice: 2}, 323 {reuse: true, bindToDevice: 0}, 324 }, 325 wantDistributions: map[tcpip.NICID][]float64{ 326 // Injected packets on dev0 get distributed among endpoints bound to 327 // dev0. 328 1: {0.5, 0.5, 0, 0, 0, 0}, 329 // Injected packets on dev1 get distributed among endpoints bound to 330 // dev1 or unbound. 331 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, 332 // Injected packets on dev999 go only to the unbound. 333 1000: {0, 0, 0, 0, 0, 1}, 334 }, 335 }, 336 } 337 protos := map[string]tcpip.NetworkProtocolNumber{ 338 "IPv4": ipv4.ProtocolNumber, 339 "IPv6": ipv6.ProtocolNumber, 340 } 341 342 for _, test := range tcs { 343 for protoName, protoNum := range protos { 344 for device, wantDistribution := range test.wantDistributions { 345 t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { 346 // Create the NICs. 347 var devices []tcpip.NICID 348 for d := range test.wantDistributions { 349 devices = append(devices, d) 350 } 351 c := newDualTestContextMultiNIC(t, defaultMTU, devices) 352 353 // Create endpoints and bind each to a NIC, sometimes reusing ports. 354 eps := make(map[tcpip.Endpoint]int) 355 pollChannel := make(chan tcpip.Endpoint) 356 for i, endpoint := range test.endpoints { 357 // Try to receive the data. 358 wq := waiter.Queue{} 359 we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) 360 wq.EventRegister(&we) 361 t.Cleanup(func() { 362 wq.EventUnregister(&we) 363 close(ch) 364 }) 365 366 var err tcpip.Error 367 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) 368 if err != nil { 369 t.Fatalf("NewEndpoint failed: %s", err) 370 } 371 t.Cleanup(ep.Close) 372 eps[ep] = i 373 374 go func(ep tcpip.Endpoint) { 375 for range ch { 376 pollChannel <- ep 377 } 378 }(ep) 379 380 ep.SocketOptions().SetReusePort(endpoint.reuse) 381 if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { 382 t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) 383 } 384 385 var dstAddr tcpip.Address 386 switch protoNum { 387 case ipv4.ProtocolNumber: 388 dstAddr = testDstAddrV4 389 case ipv6.ProtocolNumber: 390 dstAddr = testDstAddrV6 391 default: 392 t.Fatalf("unexpected protocol number: %d", protoNum) 393 } 394 if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { 395 t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) 396 } 397 } 398 399 // Send packets across a range of ports, checking that packets from 400 // the same source port are always demultiplexed to the same 401 // destination endpoint. 402 npackets := 10_000 403 nports := 1_000 404 if got, want := len(test.endpoints), len(wantDistribution); got != want { 405 t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) 406 } 407 endpoints := make(map[uint16]tcpip.Endpoint) 408 stats := make(map[tcpip.Endpoint]int) 409 for i := 0; i < npackets; i++ { 410 // Send a packet. 411 port := uint16(i % nports) 412 payload := newPayload() 413 hdrs := &headers{ 414 srcPort: testSrcPort + port, 415 dstPort: testDstPort, 416 } 417 switch protoNum { 418 case ipv4.ProtocolNumber: 419 c.sendV4Packet(payload, hdrs, device) 420 case ipv6.ProtocolNumber: 421 c.sendV6Packet(payload, hdrs, device) 422 default: 423 t.Fatalf("unexpected protocol number: %d", protoNum) 424 } 425 426 ep := <-pollChannel 427 if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { 428 t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) 429 } 430 stats[ep]++ 431 if i < nports { 432 endpoints[uint16(i)] = ep 433 } else { 434 // Check that all packets from one client are handled by the same 435 // socket. 436 if want, got := endpoints[port], ep; want != got { 437 t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) 438 } 439 } 440 } 441 442 // Check that a packet distribution is as expected. 443 for ep, i := range eps { 444 wantRatio := wantDistribution[i] 445 wantRecv := wantRatio * float64(npackets) 446 actualRecv := stats[ep] 447 actualRatio := float64(stats[ep]) / float64(npackets) 448 // The deviation is less than 10%. 449 if math.Abs(actualRatio-wantRatio) > 0.05 { 450 t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) 451 } 452 } 453 }) 454 } 455 } 456 } 457 }