github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/stack/transport_demuxer_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack_test 16 17 import ( 18 "io/ioutil" 19 "math" 20 "math/rand" 21 "strconv" 22 "testing" 23 24 "github.com/SagerNet/gvisor/pkg/tcpip" 25 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 26 "github.com/SagerNet/gvisor/pkg/tcpip/header" 27 "github.com/SagerNet/gvisor/pkg/tcpip/link/channel" 28 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4" 29 "github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6" 30 "github.com/SagerNet/gvisor/pkg/tcpip/ports" 31 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 32 "github.com/SagerNet/gvisor/pkg/tcpip/transport/udp" 33 "github.com/SagerNet/gvisor/pkg/waiter" 34 ) 35 36 const ( 37 testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" 38 testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" 39 40 testSrcAddrV4 = "\x0a\x00\x00\x01" 41 testDstAddrV4 = "\x0a\x00\x00\x02" 42 43 testDstPort = 1234 44 testSrcPort = 4096 45 ) 46 47 type testContext struct { 48 linkEps map[tcpip.NICID]*channel.Endpoint 49 s *stack.Stack 50 wq waiter.Queue 51 } 52 53 // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. 54 func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { 55 s := stack.New(stack.Options{ 56 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 57 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 58 }) 59 linkEps := make(map[tcpip.NICID]*channel.Endpoint) 60 for _, linkEpID := range linkEpIDs { 61 channelEp := channel.New(256, mtu, "") 62 if err := s.CreateNIC(linkEpID, channelEp); err != nil { 63 t.Fatalf("CreateNIC failed: %s", err) 64 } 65 linkEps[linkEpID] = channelEp 66 67 if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { 68 t.Fatalf("AddAddress IPv4 failed: %s", err) 69 } 70 71 if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { 72 t.Fatalf("AddAddress IPv6 failed: %s", err) 73 } 74 } 75 76 s.SetRouteTable([]tcpip.Route{ 77 {Destination: header.IPv4EmptySubnet, NIC: 1}, 78 {Destination: header.IPv6EmptySubnet, NIC: 1}, 79 }) 80 81 return &testContext{ 82 s: s, 83 linkEps: linkEps, 84 } 85 } 86 87 type headers struct { 88 srcPort uint16 89 dstPort uint16 90 } 91 92 func newPayload() []byte { 93 b := make([]byte, 30+rand.Intn(100)) 94 for i := range b { 95 b[i] = byte(rand.Intn(256)) 96 } 97 return b 98 } 99 100 func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { 101 buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) 102 payloadStart := len(buf) - len(payload) 103 copy(buf[payloadStart:], payload) 104 105 // Initialize the IP header. 106 ip := header.IPv4(buf) 107 ip.Encode(&header.IPv4Fields{ 108 TOS: 0x80, 109 TotalLength: uint16(len(buf)), 110 TTL: 65, 111 Protocol: uint8(udp.ProtocolNumber), 112 SrcAddr: testSrcAddrV4, 113 DstAddr: testDstAddrV4, 114 }) 115 ip.SetChecksum(^ip.CalculateChecksum()) 116 117 // Initialize the UDP header. 118 u := header.UDP(buf[header.IPv4MinimumSize:]) 119 u.Encode(&header.UDPFields{ 120 SrcPort: h.srcPort, 121 DstPort: h.dstPort, 122 Length: uint16(header.UDPMinimumSize + len(payload)), 123 }) 124 125 // Calculate the UDP pseudo-header checksum. 126 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) 127 128 // Calculate the UDP checksum and set it. 129 xsum = header.Checksum(payload, xsum) 130 u.SetChecksum(^u.CalculateChecksum(xsum)) 131 132 // Inject packet. 133 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 134 Data: buf.ToVectorisedView(), 135 }) 136 c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) 137 } 138 139 func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { 140 // Allocate a buffer for data and headers. 141 buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) 142 copy(buf[len(buf)-len(payload):], payload) 143 144 // Initialize the IP header. 145 ip := header.IPv6(buf) 146 ip.Encode(&header.IPv6Fields{ 147 PayloadLength: uint16(header.UDPMinimumSize + len(payload)), 148 TransportProtocol: udp.ProtocolNumber, 149 HopLimit: 65, 150 SrcAddr: testSrcAddrV6, 151 DstAddr: testDstAddrV6, 152 }) 153 154 // Initialize the UDP header. 155 u := header.UDP(buf[header.IPv6MinimumSize:]) 156 u.Encode(&header.UDPFields{ 157 SrcPort: h.srcPort, 158 DstPort: h.dstPort, 159 Length: uint16(header.UDPMinimumSize + len(payload)), 160 }) 161 162 // Calculate the UDP pseudo-header checksum. 163 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) 164 165 // Calculate the UDP checksum and set it. 166 xsum = header.Checksum(payload, xsum) 167 u.SetChecksum(^u.CalculateChecksum(xsum)) 168 169 // Inject packet. 170 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 171 Data: buf.ToVectorisedView(), 172 }) 173 c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) 174 } 175 176 func TestTransportDemuxerRegister(t *testing.T) { 177 for _, test := range []struct { 178 name string 179 proto tcpip.NetworkProtocolNumber 180 want tcpip.Error 181 }{ 182 {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, 183 {"success", ipv4.ProtocolNumber, nil}, 184 } { 185 t.Run(test.name, func(t *testing.T) { 186 s := stack.New(stack.Options{ 187 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 188 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 189 }) 190 var wq waiter.Queue 191 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 192 if err != nil { 193 t.Fatal(err) 194 } 195 tEP, ok := ep.(stack.TransportEndpoint) 196 if !ok { 197 t.Fatalf("%T does not implement stack.TransportEndpoint", ep) 198 } 199 if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { 200 t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) 201 } 202 }) 203 } 204 } 205 206 func TestTransportDemuxerRegisterMultiple(t *testing.T) { 207 type test struct { 208 flags ports.Flags 209 want tcpip.Error 210 } 211 for _, subtest := range []struct { 212 name string 213 tests []test 214 }{ 215 {"zeroFlags", []test{ 216 {ports.Flags{}, nil}, 217 {ports.Flags{}, &tcpip.ErrPortInUse{}}, 218 }}, 219 {"multibindFlags", []test{ 220 // Allow multiple registrations same TransportEndpointID with multibind flags. 221 {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, 222 {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, 223 // Disallow registration w/same ID for a non-multibindflag. 224 {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, 225 }}, 226 } { 227 t.Run(subtest.name, func(t *testing.T) { 228 s := stack.New(stack.Options{ 229 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 230 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 231 }) 232 var eps []tcpip.Endpoint 233 for idx, test := range subtest.tests { 234 var wq waiter.Queue 235 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 236 if err != nil { 237 t.Fatal(err) 238 } 239 eps = append(eps, ep) 240 tEP, ok := ep.(stack.TransportEndpoint) 241 if !ok { 242 t.Fatalf("%T does not implement stack.TransportEndpoint", ep) 243 } 244 id := stack.TransportEndpointID{LocalPort: 1} 245 if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { 246 t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) 247 } 248 } 249 for _, ep := range eps { 250 ep.Close() 251 } 252 }) 253 } 254 } 255 256 // TestBindToDeviceDistribution injects varied packets on input devices and checks that 257 // the distribution of packets received matches expectations. 258 func TestBindToDeviceDistribution(t *testing.T) { 259 type endpointSockopts struct { 260 reuse bool 261 bindToDevice tcpip.NICID 262 } 263 tcs := []struct { 264 name string 265 // endpoints will received the inject packets. 266 endpoints []endpointSockopts 267 // wantDistributions is the want ratio of packets received on each 268 // endpoint for each NIC on which packets are injected. 269 wantDistributions map[tcpip.NICID][]float64 270 }{ 271 { 272 name: "BindPortReuse", 273 // 5 endpoints that all have reuse set. 274 endpoints: []endpointSockopts{ 275 {reuse: true, bindToDevice: 0}, 276 {reuse: true, bindToDevice: 0}, 277 {reuse: true, bindToDevice: 0}, 278 {reuse: true, bindToDevice: 0}, 279 {reuse: true, bindToDevice: 0}, 280 }, 281 wantDistributions: map[tcpip.NICID][]float64{ 282 // Injected packets on dev0 get distributed evenly. 283 1: {0.2, 0.2, 0.2, 0.2, 0.2}, 284 }, 285 }, 286 { 287 name: "BindToDevice", 288 // 3 endpoints with various bindings. 289 endpoints: []endpointSockopts{ 290 {reuse: false, bindToDevice: 1}, 291 {reuse: false, bindToDevice: 2}, 292 {reuse: false, bindToDevice: 3}, 293 }, 294 wantDistributions: map[tcpip.NICID][]float64{ 295 // Injected packets on dev0 go only to the endpoint bound to dev0. 296 1: {1, 0, 0}, 297 // Injected packets on dev1 go only to the endpoint bound to dev1. 298 2: {0, 1, 0}, 299 // Injected packets on dev2 go only to the endpoint bound to dev2. 300 3: {0, 0, 1}, 301 }, 302 }, 303 { 304 name: "ReuseAndBindToDevice", 305 // 6 endpoints with various bindings. 306 endpoints: []endpointSockopts{ 307 {reuse: true, bindToDevice: 1}, 308 {reuse: true, bindToDevice: 1}, 309 {reuse: true, bindToDevice: 2}, 310 {reuse: true, bindToDevice: 2}, 311 {reuse: true, bindToDevice: 2}, 312 {reuse: true, bindToDevice: 0}, 313 }, 314 wantDistributions: map[tcpip.NICID][]float64{ 315 // Injected packets on dev0 get distributed among endpoints bound to 316 // dev0. 317 1: {0.5, 0.5, 0, 0, 0, 0}, 318 // Injected packets on dev1 get distributed among endpoints bound to 319 // dev1 or unbound. 320 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, 321 // Injected packets on dev999 go only to the unbound. 322 1000: {0, 0, 0, 0, 0, 1}, 323 }, 324 }, 325 } 326 protos := map[string]tcpip.NetworkProtocolNumber{ 327 "IPv4": ipv4.ProtocolNumber, 328 "IPv6": ipv6.ProtocolNumber, 329 } 330 331 for _, test := range tcs { 332 for protoName, protoNum := range protos { 333 for device, wantDistribution := range test.wantDistributions { 334 t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { 335 // Create the NICs. 336 var devices []tcpip.NICID 337 for d := range test.wantDistributions { 338 devices = append(devices, d) 339 } 340 c := newDualTestContextMultiNIC(t, defaultMTU, devices) 341 342 // Create endpoints and bind each to a NIC, sometimes reusing ports. 343 eps := make(map[tcpip.Endpoint]int) 344 pollChannel := make(chan tcpip.Endpoint) 345 for i, endpoint := range test.endpoints { 346 // Try to receive the data. 347 wq := waiter.Queue{} 348 we, ch := waiter.NewChannelEntry(nil) 349 wq.EventRegister(&we, waiter.ReadableEvents) 350 t.Cleanup(func() { 351 wq.EventUnregister(&we) 352 close(ch) 353 }) 354 355 var err tcpip.Error 356 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) 357 if err != nil { 358 t.Fatalf("NewEndpoint failed: %s", err) 359 } 360 t.Cleanup(ep.Close) 361 eps[ep] = i 362 363 go func(ep tcpip.Endpoint) { 364 for range ch { 365 pollChannel <- ep 366 } 367 }(ep) 368 369 ep.SocketOptions().SetReusePort(endpoint.reuse) 370 if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { 371 t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) 372 } 373 374 var dstAddr tcpip.Address 375 switch protoNum { 376 case ipv4.ProtocolNumber: 377 dstAddr = testDstAddrV4 378 case ipv6.ProtocolNumber: 379 dstAddr = testDstAddrV6 380 default: 381 t.Fatalf("unexpected protocol number: %d", protoNum) 382 } 383 if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { 384 t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) 385 } 386 } 387 388 // Send packets across a range of ports, checking that packets from 389 // the same source port are always demultiplexed to the same 390 // destination endpoint. 391 npackets := 10_000 392 nports := 1_000 393 if got, want := len(test.endpoints), len(wantDistribution); got != want { 394 t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) 395 } 396 endpoints := make(map[uint16]tcpip.Endpoint) 397 stats := make(map[tcpip.Endpoint]int) 398 for i := 0; i < npackets; i++ { 399 // Send a packet. 400 port := uint16(i % nports) 401 payload := newPayload() 402 hdrs := &headers{ 403 srcPort: testSrcPort + port, 404 dstPort: testDstPort, 405 } 406 switch protoNum { 407 case ipv4.ProtocolNumber: 408 c.sendV4Packet(payload, hdrs, device) 409 case ipv6.ProtocolNumber: 410 c.sendV6Packet(payload, hdrs, device) 411 default: 412 t.Fatalf("unexpected protocol number: %d", protoNum) 413 } 414 415 ep := <-pollChannel 416 if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { 417 t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) 418 } 419 stats[ep]++ 420 if i < nports { 421 endpoints[uint16(i)] = ep 422 } else { 423 // Check that all packets from one client are handled by the same 424 // socket. 425 if want, got := endpoints[port], ep; want != got { 426 t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) 427 } 428 } 429 } 430 431 // Check that a packet distribution is as expected. 432 for ep, i := range eps { 433 wantRatio := wantDistribution[i] 434 wantRecv := wantRatio * float64(npackets) 435 actualRecv := stats[ep] 436 actualRatio := float64(stats[ep]) / float64(npackets) 437 // The deviation is less than 10%. 438 if math.Abs(actualRatio-wantRatio) > 0.05 { 439 t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) 440 } 441 } 442 }) 443 } 444 } 445 } 446 }