github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/stack/transport_demuxer_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack_test 16 17 import ( 18 "math" 19 "math/rand" 20 "testing" 21 22 "github.com/FlowerWrong/netstack/tcpip" 23 "github.com/FlowerWrong/netstack/tcpip/buffer" 24 "github.com/FlowerWrong/netstack/tcpip/header" 25 "github.com/FlowerWrong/netstack/tcpip/link/channel" 26 "github.com/FlowerWrong/netstack/tcpip/network/ipv4" 27 "github.com/FlowerWrong/netstack/tcpip/network/ipv6" 28 "github.com/FlowerWrong/netstack/tcpip/stack" 29 "github.com/FlowerWrong/netstack/tcpip/transport/udp" 30 "github.com/FlowerWrong/netstack/waiter" 31 ) 32 33 const ( 34 stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" 35 testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" 36 37 stackAddr = "\x0a\x00\x00\x01" 38 stackPort = 1234 39 testPort = 4096 40 ) 41 42 type testContext struct { 43 t *testing.T 44 linkEPs map[string]*channel.Endpoint 45 s *stack.Stack 46 47 ep tcpip.Endpoint 48 wq waiter.Queue 49 } 50 51 func (c *testContext) cleanup() { 52 if c.ep != nil { 53 c.ep.Close() 54 } 55 } 56 57 func (c *testContext) createV6Endpoint(v6only bool) { 58 var err *tcpip.Error 59 c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) 60 if err != nil { 61 c.t.Fatalf("NewEndpoint failed: %v", err) 62 } 63 64 var v tcpip.V6OnlyOption 65 if v6only { 66 v = 1 67 } 68 if err := c.ep.SetSockOpt(v); err != nil { 69 c.t.Fatalf("SetSockOpt failed: %v", err) 70 } 71 } 72 73 // newDualTestContextMultiNic creates the testing context and also linkEpNames 74 // named NICs. 75 func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { 76 s := stack.New(stack.Options{ 77 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, 78 TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) 79 linkEPs := make(map[string]*channel.Endpoint) 80 for i, linkEpName := range linkEpNames { 81 channelEP := channel.New(256, mtu, "") 82 nicid := tcpip.NICID(i + 1) 83 if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { 84 t.Fatalf("CreateNIC failed: %v", err) 85 } 86 linkEPs[linkEpName] = channelEP 87 88 if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { 89 t.Fatalf("AddAddress IPv4 failed: %v", err) 90 } 91 92 if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { 93 t.Fatalf("AddAddress IPv6 failed: %v", err) 94 } 95 } 96 97 s.SetRouteTable([]tcpip.Route{ 98 { 99 Destination: header.IPv4EmptySubnet, 100 NIC: 1, 101 }, 102 { 103 Destination: header.IPv6EmptySubnet, 104 NIC: 1, 105 }, 106 }) 107 108 return &testContext{ 109 t: t, 110 s: s, 111 linkEPs: linkEPs, 112 } 113 } 114 115 type headers struct { 116 srcPort uint16 117 dstPort uint16 118 } 119 120 func newPayload() []byte { 121 b := make([]byte, 30+rand.Intn(100)) 122 for i := range b { 123 b[i] = byte(rand.Intn(256)) 124 } 125 return b 126 } 127 128 func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { 129 // Allocate a buffer for data and headers. 130 buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) 131 copy(buf[len(buf)-len(payload):], payload) 132 133 // Initialize the IP header. 134 ip := header.IPv6(buf) 135 ip.Encode(&header.IPv6Fields{ 136 PayloadLength: uint16(header.UDPMinimumSize + len(payload)), 137 NextHeader: uint8(udp.ProtocolNumber), 138 HopLimit: 65, 139 SrcAddr: testV6Addr, 140 DstAddr: stackV6Addr, 141 }) 142 143 // Initialize the UDP header. 144 u := header.UDP(buf[header.IPv6MinimumSize:]) 145 u.Encode(&header.UDPFields{ 146 SrcPort: h.srcPort, 147 DstPort: h.dstPort, 148 Length: uint16(header.UDPMinimumSize + len(payload)), 149 }) 150 151 // Calculate the UDP pseudo-header checksum. 152 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) 153 154 // Calculate the UDP checksum and set it. 155 xsum = header.Checksum(payload, xsum) 156 u.SetChecksum(^u.CalculateChecksum(xsum)) 157 158 // Inject packet. 159 c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) 160 } 161 162 func TestTransportDemuxerRegister(t *testing.T) { 163 for _, test := range []struct { 164 name string 165 proto tcpip.NetworkProtocolNumber 166 want *tcpip.Error 167 }{ 168 {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, 169 {"success", ipv4.ProtocolNumber, nil}, 170 } { 171 t.Run(test.name, func(t *testing.T) { 172 s := stack.New(stack.Options{ 173 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, 174 TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) 175 if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { 176 t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) 177 } 178 }) 179 } 180 } 181 182 // TestReuseBindToDevice injects varied packets on input devices and checks that 183 // the distribution of packets received matches expectations. 184 func TestDistribution(t *testing.T) { 185 type endpointSockopts struct { 186 reuse int 187 bindToDevice string 188 } 189 for _, test := range []struct { 190 name string 191 // endpoints will received the inject packets. 192 endpoints []endpointSockopts 193 // wantedDistribution is the wanted ratio of packets received on each 194 // endpoint for each NIC on which packets are injected. 195 wantedDistributions map[string][]float64 196 }{ 197 { 198 "BindPortReuse", 199 // 5 endpoints that all have reuse set. 200 []endpointSockopts{ 201 endpointSockopts{1, ""}, 202 endpointSockopts{1, ""}, 203 endpointSockopts{1, ""}, 204 endpointSockopts{1, ""}, 205 endpointSockopts{1, ""}, 206 }, 207 map[string][]float64{ 208 // Injected packets on dev0 get distributed evenly. 209 "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, 210 }, 211 }, 212 { 213 "BindToDevice", 214 // 3 endpoints with various bindings. 215 []endpointSockopts{ 216 endpointSockopts{0, "dev0"}, 217 endpointSockopts{0, "dev1"}, 218 endpointSockopts{0, "dev2"}, 219 }, 220 map[string][]float64{ 221 // Injected packets on dev0 go only to the endpoint bound to dev0. 222 "dev0": []float64{1, 0, 0}, 223 // Injected packets on dev1 go only to the endpoint bound to dev1. 224 "dev1": []float64{0, 1, 0}, 225 // Injected packets on dev2 go only to the endpoint bound to dev2. 226 "dev2": []float64{0, 0, 1}, 227 }, 228 }, 229 { 230 "ReuseAndBindToDevice", 231 // 6 endpoints with various bindings. 232 []endpointSockopts{ 233 endpointSockopts{1, "dev0"}, 234 endpointSockopts{1, "dev0"}, 235 endpointSockopts{1, "dev1"}, 236 endpointSockopts{1, "dev1"}, 237 endpointSockopts{1, "dev1"}, 238 endpointSockopts{1, ""}, 239 }, 240 map[string][]float64{ 241 // Injected packets on dev0 get distributed among endpoints bound to 242 // dev0. 243 "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, 244 // Injected packets on dev1 get distributed among endpoints bound to 245 // dev1 or unbound. 246 "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, 247 // Injected packets on dev999 go only to the unbound. 248 "dev999": []float64{0, 0, 0, 0, 0, 1}, 249 }, 250 }, 251 } { 252 t.Run(test.name, func(t *testing.T) { 253 for device, wantedDistribution := range test.wantedDistributions { 254 t.Run(device, func(t *testing.T) { 255 var devices []string 256 for d := range test.wantedDistributions { 257 devices = append(devices, d) 258 } 259 c := newDualTestContextMultiNic(t, defaultMTU, devices) 260 defer c.cleanup() 261 262 c.createV6Endpoint(false) 263 264 eps := make(map[tcpip.Endpoint]int) 265 266 pollChannel := make(chan tcpip.Endpoint) 267 for i, endpoint := range test.endpoints { 268 // Try to receive the data. 269 wq := waiter.Queue{} 270 we, ch := waiter.NewChannelEntry(nil) 271 wq.EventRegister(&we, waiter.EventIn) 272 defer wq.EventUnregister(&we) 273 defer close(ch) 274 275 var err *tcpip.Error 276 ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) 277 if err != nil { 278 c.t.Fatalf("NewEndpoint failed: %v", err) 279 } 280 eps[ep] = i 281 282 go func(ep tcpip.Endpoint) { 283 for range ch { 284 pollChannel <- ep 285 } 286 }(ep) 287 288 defer ep.Close() 289 reusePortOption := tcpip.ReusePortOption(endpoint.reuse) 290 if err := ep.SetSockOpt(reusePortOption); err != nil { 291 c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) 292 } 293 bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) 294 if err := ep.SetSockOpt(bindToDeviceOption); err != nil { 295 c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) 296 } 297 if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { 298 t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) 299 } 300 } 301 302 npackets := 100000 303 nports := 10000 304 if got, want := len(test.endpoints), len(wantedDistribution); got != want { 305 t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) 306 } 307 ports := make(map[uint16]tcpip.Endpoint) 308 stats := make(map[tcpip.Endpoint]int) 309 for i := 0; i < npackets; i++ { 310 // Send a packet. 311 port := uint16(i % nports) 312 payload := newPayload() 313 c.sendV6Packet(payload, 314 &headers{ 315 srcPort: testPort + port, 316 dstPort: stackPort}, 317 device) 318 319 var addr tcpip.FullAddress 320 ep := <-pollChannel 321 _, _, err := ep.Read(&addr) 322 if err != nil { 323 c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) 324 } 325 stats[ep]++ 326 if i < nports { 327 ports[uint16(i)] = ep 328 } else { 329 // Check that all packets from one client are handled by the same 330 // socket. 331 if want, got := ports[port], ep; want != got { 332 t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) 333 } 334 } 335 } 336 337 // Check that a packet distribution is as expected. 338 for ep, i := range eps { 339 wantedRatio := wantedDistribution[i] 340 wantedRecv := wantedRatio * float64(npackets) 341 actualRecv := stats[ep] 342 actualRatio := float64(stats[ep]) / float64(npackets) 343 // The deviation is less than 10%. 344 if math.Abs(actualRatio-wantedRatio) > 0.05 { 345 t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) 346 } 347 } 348 }) 349 } 350 }) 351 } 352 }