gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/udp/udp_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package udp_test 16 17 import ( 18 "bytes" 19 "encoding/binary" 20 "fmt" 21 "io/ioutil" 22 "math" 23 "math/rand" 24 "os" 25 "testing" 26 27 "gvisor.dev/gvisor/pkg/buffer" 28 "gvisor.dev/gvisor/pkg/refs" 29 "gvisor.dev/gvisor/pkg/tcpip" 30 "gvisor.dev/gvisor/pkg/tcpip/checker" 31 "gvisor.dev/gvisor/pkg/tcpip/checksum" 32 "gvisor.dev/gvisor/pkg/tcpip/faketime" 33 "gvisor.dev/gvisor/pkg/tcpip/header" 34 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 35 "gvisor.dev/gvisor/pkg/tcpip/link/loopback" 36 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 37 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 38 "gvisor.dev/gvisor/pkg/tcpip/stack" 39 "gvisor.dev/gvisor/pkg/tcpip/testutil" 40 "gvisor.dev/gvisor/pkg/tcpip/transport" 41 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 42 "gvisor.dev/gvisor/pkg/tcpip/transport/testing/context" 43 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 44 "gvisor.dev/gvisor/pkg/waiter" 45 ) 46 47 const ( 48 testTOS = 0x80 49 testTTL = 65 50 arbitraryPayloadSize = 30 51 ) 52 53 // newRandomPayload returns a payload with the specified size and with 54 // randomized content. 55 func newRandomPayload(size int) []byte { 56 b := make([]byte, size) 57 for i := range b { 58 b[i] = byte(rand.Intn(math.MaxUint8 + 1)) 59 } 60 return b 61 } 62 63 func testRead(c *context.Context, flow context.TestFlow, checkers ...checker.ControlMessagesChecker) { 64 c.T.Helper() 65 66 payload := newRandomPayload(arbitraryPayloadSize) 67 c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(payload, flow, context.Incoming, testTOS, testTTL, false)) 68 c.ReadFromEndpointExpectSuccess(payload, flow, checkers...) 69 } 70 71 func testFailingRead(c *context.Context, flow context.TestFlow, expectReadError bool) { 72 c.T.Helper() 73 74 c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(newRandomPayload(arbitraryPayloadSize), flow, context.Incoming, testTOS, testTTL, false)) 75 if expectReadError { 76 c.ReadFromEndpointExpectError() 77 } else { 78 c.ReadFromEndpointExpectNoPacket() 79 } 80 } 81 82 func TestBindToDeviceOption(t *testing.T) { 83 s := stack.New(stack.Options{ 84 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 85 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 86 Clock: &faketime.NullClock{}, 87 }) 88 defer s.Destroy() 89 90 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) 91 if err != nil { 92 t.Fatalf("NewEndpoint failed; %s", err) 93 } 94 defer ep.Close() 95 96 opts := stack.NICOptions{Name: "my_device"} 97 if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { 98 t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) 99 } 100 101 // nicIDPtr is used instead of taking the address of NICID literals, which is 102 // a compiler error. 103 nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { 104 return &s 105 } 106 107 testActions := []struct { 108 name string 109 setBindToDevice *tcpip.NICID 110 setBindToDeviceError tcpip.Error 111 getBindToDevice int32 112 }{ 113 {"GetDefaultValue", nil, nil, 0}, 114 {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, 115 {"BindToExistent", nicIDPtr(321), nil, 321}, 116 {"UnbindToDevice", nicIDPtr(0), nil, 0}, 117 } 118 for _, testAction := range testActions { 119 t.Run(testAction.name, func(t *testing.T) { 120 if testAction.setBindToDevice != nil { 121 bindToDevice := int32(*testAction.setBindToDevice) 122 if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { 123 t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) 124 } 125 } 126 bindToDevice := ep.SocketOptions().GetBindToDevice() 127 if bindToDevice != testAction.getBindToDevice { 128 t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) 129 } 130 }) 131 } 132 } 133 134 func TestBindEphemeralPort(t *testing.T) { 135 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 136 defer c.Cleanup() 137 138 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 139 140 if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { 141 t.Fatalf("ep.Bind(...) failed: %s", err) 142 } 143 } 144 145 func TestBindReservedPort(t *testing.T) { 146 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 147 defer c.Cleanup() 148 149 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 150 151 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 152 c.T.Fatalf("Connect failed: %s", err) 153 } 154 155 addr, err := c.EP.GetLocalAddress() 156 if err != nil { 157 t.Fatalf("GetLocalAddress failed: %s", err) 158 } 159 160 // We can't bind the address reserved by the connected endpoint above. 161 { 162 ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) 163 if err != nil { 164 t.Fatalf("NewEndpoint failed: %s", err) 165 } 166 defer ep.Close() 167 { 168 err := ep.Bind(addr) 169 if _, ok := err.(*tcpip.ErrPortInUse); !ok { 170 t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) 171 } 172 } 173 } 174 175 func() { 176 ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 177 if err != nil { 178 t.Fatalf("NewEndpoint failed: %s", err) 179 } 180 defer ep.Close() 181 // We can't bind ipv4-any on the port reserved by the connected endpoint 182 // above, since the endpoint is dual-stack. 183 { 184 err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) 185 if _, ok := err.(*tcpip.ErrPortInUse); !ok { 186 t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) 187 } 188 } 189 // We can bind an ipv4 address on this port, though. 190 if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: addr.Port}); err != nil { 191 t.Fatalf("ep.Bind(...) failed: %s", err) 192 } 193 }() 194 195 // Once the connected endpoint releases its port reservation, we are able to 196 // bind ipv4-any once again. 197 c.EP.Close() 198 func() { 199 ep, err := c.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) 200 if err != nil { 201 t.Fatalf("NewEndpoint failed: %s", err) 202 } 203 defer ep.Close() 204 if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { 205 t.Fatalf("ep.Bind(...) failed: %s", err) 206 } 207 }() 208 } 209 210 func TestV4ReadOnV6(t *testing.T) { 211 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 212 defer c.Cleanup() 213 214 c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber) 215 216 // Bind to wildcard. 217 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 218 c.T.Fatalf("Bind failed: %s", err) 219 } 220 221 payload := newRandomPayload(arbitraryPayloadSize) 222 buf := context.BuildUDPPacket(payload, context.UnicastV4in6, context.Incoming, testTOS, testTTL, false) 223 c.InjectPacket(header.IPv4ProtocolNumber, buf) 224 c.ReadFromEndpointExpectSuccess(payload, context.UnicastV4in6) 225 } 226 227 func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { 228 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 229 defer c.Cleanup() 230 231 c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber) 232 233 // Bind to v4 mapped wildcard. 234 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { 235 c.T.Fatalf("Bind failed: %s", err) 236 } 237 238 // Test acceptance. 239 testRead(c, context.UnicastV4in6) 240 } 241 242 func TestV4ReadOnBoundToV4Mapped(t *testing.T) { 243 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 244 defer c.Cleanup() 245 246 c.CreateEndpointForFlow(context.UnicastV4in6, udp.ProtocolNumber) 247 248 // Bind to local address. 249 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { 250 c.T.Fatalf("Bind failed: %s", err) 251 } 252 253 // Test acceptance. 254 testRead(c, context.UnicastV4in6) 255 } 256 257 func TestV6ReadOnV6(t *testing.T) { 258 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 259 defer c.Cleanup() 260 261 c.CreateEndpointForFlow(context.UnicastV6, udp.ProtocolNumber) 262 263 // Bind to wildcard. 264 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 265 c.T.Fatalf("Bind failed: %s", err) 266 } 267 268 // Test acceptance. 269 testRead(c, context.UnicastV6) 270 } 271 272 // TestV4ReadSelfSource checks that packets coming from a local IP address are 273 // correctly dropped when handleLocal is true and not otherwise. 274 func TestV4ReadSelfSource(t *testing.T) { 275 for _, tt := range []struct { 276 name string 277 handleLocal bool 278 wantErr tcpip.Error 279 wantInvalidSource uint64 280 }{ 281 {"HandleLocal", false, nil, 0}, 282 {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, 283 } { 284 t.Run(tt.name, func(t *testing.T) { 285 c := context.NewWithOptions(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, context.Options{ 286 MTU: context.DefaultMTU, 287 HandleLocal: tt.handleLocal, 288 }) 289 defer c.Cleanup() 290 291 c.CreateEndpointForFlow(context.UnicastV4, udp.ProtocolNumber) 292 293 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 294 t.Fatalf("Bind failed: %s", err) 295 } 296 297 payload := newRandomPayload(arbitraryPayloadSize) 298 h := context.UnicastV4.MakeHeader4Tuple(context.Incoming) 299 h.Src = h.Dst 300 c.InjectPacket(header.IPv4ProtocolNumber, context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false)) 301 302 if got := c.Stack.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { 303 t.Errorf("c.Stack.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) 304 } 305 306 if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { 307 t.Errorf("got c.EP.Read = %s, want = %s", err, tt.wantErr) 308 } 309 }) 310 } 311 } 312 313 func TestV4ReadOnV4(t *testing.T) { 314 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 315 defer c.Cleanup() 316 317 c.CreateEndpointForFlow(context.UnicastV4, udp.ProtocolNumber) 318 319 // Bind to wildcard. 320 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 321 c.T.Fatalf("Bind failed: %s", err) 322 } 323 324 // Test acceptance. 325 testRead(c, context.UnicastV4) 326 } 327 328 // TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast 329 // address and receive data sent to that address. 330 func TestReadOnBoundToMulticast(t *testing.T) { 331 // FIXME(b/128189410): context.MulticastV4in6 currently doesn't work as 332 // AddMembershipOption doesn't handle V4in6 addresses. 333 for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV6, context.MulticastV6Only} { 334 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 335 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 336 defer c.Cleanup() 337 338 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 339 340 // Bind to multicast address. 341 mcastAddr := flow.MapAddrIfApplicable(flow.GetMulticastAddr()) 342 if err := c.EP.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: context.StackPort}); err != nil { 343 c.T.Fatal("Bind failed:", err) 344 } 345 346 // Join multicast group. 347 ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} 348 if err := c.EP.SetSockOpt(&ifoptSet); err != nil { 349 c.T.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) 350 } 351 352 // Check that we receive multicast packets but not unicast or broadcast 353 // ones. 354 testRead(c, flow) 355 testFailingRead(c, context.Broadcast, false /* expectReadError */) 356 testFailingRead(c, context.UnicastV4, false /* expectReadError */) 357 }) 358 } 359 } 360 361 // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast 362 // address and can receive only broadcast data. 363 func TestV4ReadOnBoundToBroadcast(t *testing.T) { 364 for _, flow := range []context.TestFlow{context.Broadcast, context.BroadcastIn6} { 365 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 366 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 367 defer c.Cleanup() 368 369 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 370 371 // Bind to broadcast address. 372 broadcastAddr := flow.MapAddrIfApplicable(context.BroadcastAddr) 373 if err := c.EP.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: context.StackPort}); err != nil { 374 c.T.Fatalf("Bind failed: %s", err) 375 } 376 377 // Check that we receive broadcast packets but not unicast ones. 378 testRead(c, flow) 379 testFailingRead(c, context.UnicastV4, false /* expectReadError */) 380 }) 381 } 382 } 383 384 // TestReadFromMulticast checks that an endpoint will NOT receive a packet 385 // that was sent with multicast SOURCE address. 386 func TestReadFromMulticast(t *testing.T) { 387 for _, flow := range []context.TestFlow{context.ReverseMulticastV4, context.ReverseMulticastV6} { 388 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 389 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 390 defer c.Cleanup() 391 392 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 393 394 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 395 t.Fatalf("Bind failed: %s", err) 396 } 397 testFailingRead(c, flow, false /* expectReadError */) 398 }) 399 } 400 } 401 402 // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY 403 // and receive broadcast and unicast data. 404 func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { 405 for _, flow := range []context.TestFlow{context.Broadcast, context.BroadcastIn6} { 406 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 407 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 408 defer c.Cleanup() 409 410 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 411 412 // Bind to wildcard. 413 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 414 c.T.Fatalf("Bind failed: %s (", err) 415 } 416 417 // Check that we receive both broadcast and unicast packets. 418 testRead(c, flow) 419 testRead(c, context.UnicastV4) 420 }) 421 } 422 } 423 424 func getEndpointWithPreflight(c *context.Context) tcpip.EndpointWithPreflight { 425 epWithPreflight, ok := c.EP.(tcpip.EndpointWithPreflight) 426 427 if !ok { 428 c.T.Fatalf("expect endpoint implements tcpip.EndpointWithPreflight; found endpoint with type %T does not", c.EP) 429 } 430 return epWithPreflight 431 } 432 433 func getWriteOptionsForFlow(flow context.TestFlow) tcpip.WriteOptions { 434 h := flow.MakeHeader4Tuple(context.Outgoing) 435 writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr) 436 return tcpip.WriteOptions{ 437 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port}, 438 } 439 } 440 441 // testWriteFails calls the endpoint's Write method with a packet of the 442 // given test flow, verifying that the method fails with the provided error 443 // code. 444 // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the 445 // testing context. 446 func testWriteFails(c *context.Context, flow context.TestFlow, payloadSize int, wantErr tcpip.Error) { 447 c.T.Helper() 448 // Take a snapshot of the stats to validate them at the end of the test. 449 var epstats tcpip.TransportEndpointStats 450 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats) 451 452 var r bytes.Reader 453 r.Reset(newRandomPayload(payloadSize)) 454 _, gotErr := c.EP.Write(&r, getWriteOptionsForFlow(flow)) 455 c.CheckEndpointWriteStats(1, &epstats, gotErr) 456 if gotErr != wantErr { 457 c.T.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) 458 } 459 } 460 461 // testPreflightSucceeds calls the endpoint's Preflight method with a 462 // destination of the given flow, verifying that it succeeds. 463 func testPreflightSucceeds(c *context.Context, flow context.TestFlow) { 464 c.T.Helper() 465 testPreflightImpl(c, flow, true, nil) 466 } 467 468 // testPreflightFails calls the endpoint's Preflight method with a destination 469 // of the given flow, verifying that it fails with the provided error. 470 func testPreflightFails(c *context.Context, flow context.TestFlow, wantErr tcpip.Error) { 471 c.T.Helper() 472 testPreflightImpl(c, flow, true, wantErr) 473 } 474 475 func testPreflightImpl(c *context.Context, flow context.TestFlow, setDest bool, wantErr tcpip.Error) { 476 c.T.Helper() 477 // Take a snapshot of the stats to validate them at the end of the test. 478 var epstats tcpip.TransportEndpointStats 479 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats) 480 481 writeOpts := tcpip.WriteOptions{} 482 if setDest { 483 writeOpts = getWriteOptionsForFlow(flow) 484 } 485 486 gotErr := getEndpointWithPreflight(c).Preflight(writeOpts) 487 if gotErr != wantErr { 488 c.T.Fatalf("Preflight returned unexpected error: got %v, want %v", gotErr, wantErr) 489 } 490 491 c.CheckEndpointWriteStats(0, &epstats, gotErr) 492 } 493 494 type writeOperation int 495 496 const ( 497 write writeOperation = iota 498 preflight 499 ) 500 501 // testWriteOpSequenceSucceeds calls the provided sequence of write operations with a packet of the 502 // given test flow, verifying that each operation succeeds. 503 func testWriteOpSequenceSucceeds(c *context.Context, flow context.TestFlow, ops []writeOperation, checkers ...checker.NetworkChecker) { 504 c.T.Helper() 505 for _, op := range ops { 506 switch op { 507 case write: 508 testWriteSucceedsAndGetReceivedSrcPort(c, flow, checkers...) 509 case preflight: 510 testPreflightSucceeds(c, flow) 511 } 512 } 513 } 514 515 // testWriteOpSequenceSucceedsNoDestination calls the provided sequence of write operations with a 516 // packet of the given test flow, without giving a destination address:port, verifying that each 517 // operation succeeds. 518 func testWriteOpSequenceSucceedsNoDestination(c *context.Context, flow context.TestFlow, ops []writeOperation) { 519 c.T.Helper() 520 for _, op := range ops { 521 switch op { 522 case write: 523 testWriteAndVerifyInternal(c, flow, false /* setDest */) 524 case preflight: 525 testPreflightImpl(c, flow, false /* setDest */, nil /* wantErr */) 526 } 527 } 528 } 529 530 // testWriteOpSequenceFails calls the provided sequence of write operations with a packet of the 531 // given test flow, verifying that each operation fails with the provided err. 532 func testWriteOpSequenceFails(c *context.Context, flow context.TestFlow, ops []writeOperation, err tcpip.Error) { 533 c.T.Helper() 534 for _, op := range ops { 535 switch op { 536 case write: 537 testWriteFails(c, flow, arbitraryPayloadSize, err) 538 case preflight: 539 testPreflightFails(c, flow, err) 540 } 541 } 542 } 543 544 // testWriteSucceedsAndGetReceivedSrcPort calls the endpoint's Write method with a packet of the 545 // given test flow and a destination constructed from the flow's destination address:port. It then 546 // receives the packet from the link endpoint and verifies its correctness using the 547 // provided checker functions, returning the found source port. 548 // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the 549 // testing context. 550 func testWriteSucceedsAndGetReceivedSrcPort(c *context.Context, flow context.TestFlow, checkers ...checker.NetworkChecker) uint16 { 551 c.T.Helper() 552 return testWriteAndVerifyInternal(c, flow, true, checkers...) 553 } 554 555 // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the 556 // testing context. 557 func testWriteNoVerify(c *context.Context, flow context.TestFlow, setDest bool) []byte { 558 c.T.Helper() 559 // Take a snapshot of the stats to validate them at the end of the test. 560 var epstats tcpip.TransportEndpointStats 561 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats) 562 563 writeOpts := tcpip.WriteOptions{} 564 if setDest { 565 h := flow.MakeHeader4Tuple(context.Outgoing) 566 writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr) 567 writeOpts = tcpip.WriteOptions{ 568 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port}, 569 } 570 } 571 572 var r bytes.Reader 573 payload := newRandomPayload(arbitraryPayloadSize) 574 r.Reset(payload) 575 n, err := c.EP.Write(&r, writeOpts) 576 if err != nil { 577 c.T.Fatalf("Write failed: %s", err) 578 } 579 if n != int64(len(payload)) { 580 c.T.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) 581 } 582 c.CheckEndpointWriteStats(1, &epstats, err) 583 return payload 584 } 585 586 // TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the 587 // testing context. 588 func testWriteAndVerifyInternal(c *context.Context, flow context.TestFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { 589 c.T.Helper() 590 payload := testWriteNoVerify(c, flow, setDest) 591 // Received the packet and check the payload. 592 593 p := c.LinkEP.Read() 594 if p == nil { 595 c.T.Fatalf("Packet wasn't written out") 596 } 597 defer p.DecRef() 598 599 if got, want := p.NetworkProtocolNumber, flow.NetProto(); got != want { 600 c.T.Fatalf("got p.NetworkProtocolNumber = %d, want = %d", got, want) 601 } 602 603 if got, want := p.TransportProtocolNumber, header.UDPProtocolNumber; got != want { 604 c.T.Errorf("got p.TransportProtocolNumber = %d, want = %d", got, want) 605 } 606 607 v := p.ToView() 608 defer v.Release() 609 610 h := flow.MakeHeader4Tuple(context.Outgoing) 611 checkers = append( 612 checkers, 613 checker.SrcAddr(h.Src.Addr), 614 checker.DstAddr(h.Dst.Addr), 615 checker.UDP(checker.DstPort(h.Dst.Port)), 616 ) 617 flow.CheckerFn()(c.T, v, checkers...) 618 619 var udpH header.UDP 620 if flow.IsV4() { 621 udpH = header.IPv4(v.AsSlice()).Payload() 622 } else { 623 udpH = header.IPv6(v.AsSlice()).Payload() 624 } 625 if !bytes.Equal(payload, udpH.Payload()) { 626 c.T.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) 627 } 628 629 return udpH.SourcePort() 630 } 631 632 func testDualWrite(c *context.Context) uint16 { 633 c.T.Helper() 634 635 v4Port := testWriteSucceedsAndGetReceivedSrcPort(c, context.UnicastV4in6) 636 v6Port := testWriteSucceedsAndGetReceivedSrcPort(c, context.UnicastV6) 637 if v4Port != v6Port { 638 c.T.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) 639 } 640 641 return v4Port 642 } 643 644 func TestDualWriteUnbound(t *testing.T) { 645 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 646 defer c.Cleanup() 647 648 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 649 650 testDualWrite(c) 651 } 652 653 func TestDualWriteBoundToWildcard(t *testing.T) { 654 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 655 defer c.Cleanup() 656 657 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 658 659 // Bind to wildcard. 660 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 661 c.T.Fatalf("Bind failed: %s", err) 662 } 663 664 p := testDualWrite(c) 665 if p != context.StackPort { 666 c.T.Fatalf("Bad port: got %v, want %v", p, context.StackPort) 667 } 668 } 669 670 func TestDualWriteConnectedToV6(t *testing.T) { 671 for _, testCase := range []struct { 672 writeOpSequence []writeOperation 673 expectedNoRouteErrCount uint64 674 }{ 675 {writeOpSequence: []writeOperation{write}, expectedNoRouteErrCount: 1}, 676 {writeOpSequence: []writeOperation{preflight}, expectedNoRouteErrCount: 0}, 677 {writeOpSequence: []writeOperation{preflight, write}, expectedNoRouteErrCount: 1}, 678 } { 679 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 680 681 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 682 683 // Connect to v6 address. 684 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 685 c.T.Fatalf("Bind failed: %s", err) 686 } 687 688 testWriteOpSequenceSucceeds(c, context.UnicastV6, testCase.writeOpSequence) 689 690 // Write to V4 mapped address. 691 testWriteOpSequenceFails(c, context.UnicastV4in6, testCase.writeOpSequence, &tcpip.ErrNetworkUnreachable{}) 692 693 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != testCase.expectedNoRouteErrCount { 694 c.T.Fatalf("Endpoint stat not updated. got %d want %d", got, testCase.expectedNoRouteErrCount) 695 } 696 c.Cleanup() 697 } 698 } 699 700 var writeOpSequences = map[string]([]writeOperation){ 701 "write": []writeOperation{write}, 702 "preflight": []writeOperation{preflight}, 703 "write|preflight": []writeOperation{preflight, write}, 704 } 705 706 func TestDualWriteConnectedToV4Mapped(t *testing.T) { 707 for name, writeOpSequence := range writeOpSequences { 708 t.Run(name, func(t *testing.T) { 709 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 710 defer c.Cleanup() 711 712 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 713 714 // Connect to v4 mapped address. 715 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}); err != nil { 716 c.T.Fatalf("Bind failed: %s", err) 717 } 718 719 testWriteOpSequenceSucceeds(c, context.UnicastV4in6, writeOpSequence) 720 721 // Write to v6 address. 722 testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrInvalidEndpointState{}) 723 }) 724 } 725 } 726 727 func TestPreflightBindsEndpoint(t *testing.T) { 728 tcs := []struct { 729 name string 730 proto tcpip.NetworkProtocolNumber 731 flow context.TestFlow 732 }{ 733 { 734 name: "ipv4", 735 proto: ipv4.ProtocolNumber, 736 flow: context.UnicastV4, 737 }, 738 { 739 name: "ipv6", 740 proto: ipv6.ProtocolNumber, 741 flow: context.UnicastV6, 742 }, 743 } 744 for _, tc := range tcs { 745 t.Run(tc.name, func(t *testing.T) { 746 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol}) 747 defer c.Cleanup() 748 749 c.CreateEndpoint(tc.proto, udp.ProtocolNumber) 750 751 h := tc.flow.MakeHeader4Tuple(context.Outgoing) 752 writeDstAddr := tc.flow.MapAddrIfApplicable(h.Dst.Addr) 753 writeOpts := tcpip.WriteOptions{ 754 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port}, 755 } 756 757 if err := getEndpointWithPreflight(c).Preflight(writeOpts); err != nil { 758 c.T.Fatalf("Preflight failed: %s", err) 759 } 760 761 if c.EP.State() != uint32(transport.DatagramEndpointStateBound) { 762 c.T.Fatalf("Expect UDP endpoint in state %d, found %d", transport.DatagramEndpointStateBound, c.EP.State()) 763 } 764 }) 765 } 766 } 767 768 func TestV4WriteOnV6Only(t *testing.T) { 769 for name, writeOpSequence := range writeOpSequences { 770 t.Run(name, func(t *testing.T) { 771 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 772 defer c.Cleanup() 773 774 c.CreateEndpointForFlow(context.UnicastV6Only, udp.ProtocolNumber) 775 776 // Write to V4 mapped address. 777 testWriteOpSequenceFails(c, context.UnicastV4in6, writeOpSequence, &tcpip.ErrHostUnreachable{}) 778 }) 779 } 780 } 781 782 func TestV6WriteOnBoundToV4Mapped(t *testing.T) { 783 for name, writeOpSequence := range writeOpSequences { 784 t.Run(name, func(t *testing.T) { 785 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 786 defer c.Cleanup() 787 788 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 789 790 // Bind to v4 mapped address. 791 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { 792 c.T.Fatalf("Bind failed: %s", err) 793 } 794 795 // Write to v6 address. 796 testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrInvalidEndpointState{}) 797 }) 798 } 799 } 800 801 func TestV6WriteOnConnected(t *testing.T) { 802 for name, writeOpSequence := range writeOpSequences { 803 t.Run(name, func(t *testing.T) { 804 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 805 defer c.Cleanup() 806 807 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 808 809 // Connect to v6 address. 810 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 811 c.T.Fatalf("Connect failed: %s", err) 812 } 813 814 testWriteOpSequenceSucceedsNoDestination(c, context.UnicastV6, writeOpSequence) 815 }) 816 } 817 } 818 819 func TestV4WriteOnConnected(t *testing.T) { 820 for name, writeOpSequence := range writeOpSequences { 821 t.Run(name, func(t *testing.T) { 822 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 823 defer c.Cleanup() 824 825 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 826 827 // Connect to v4 mapped address. 828 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}); err != nil { 829 c.T.Fatalf("Connect failed: %s", err) 830 } 831 832 testWriteOpSequenceSucceedsNoDestination(c, context.UnicastV4, writeOpSequence) 833 }) 834 } 835 } 836 837 func TestWriteOnConnectedInvalidPort(t *testing.T) { 838 const invalidPort = 8192 839 protocols := map[string]tcpip.NetworkProtocolNumber{ 840 "ipv4": ipv4.ProtocolNumber, 841 "ipv6": ipv6.ProtocolNumber, 842 } 843 for name, proto := range protocols { 844 t.Run(name, func(t *testing.T) { 845 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 846 defer c.Cleanup() 847 848 c.CreateEndpoint(proto, udp.ProtocolNumber) 849 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: invalidPort}); err != nil { 850 c.T.Fatalf("Connect failed: %s", err) 851 } 852 writeOpts := tcpip.WriteOptions{ 853 To: &tcpip.FullAddress{Addr: context.StackAddr, Port: invalidPort}, 854 } 855 var r bytes.Reader 856 payload := newRandomPayload(arbitraryPayloadSize) 857 r.Reset(payload) 858 n, err := c.EP.Write(&r, writeOpts) 859 if err != nil { 860 c.T.Fatalf("c.EP.Write(...) = %s, want nil", err) 861 } 862 if got, want := n, int64(len(payload)); got != want { 863 c.T.Fatalf("c.EP.Write(...) wrote %d bytes, want %d bytes", got, want) 864 } 865 866 { 867 err := c.EP.LastError() 868 if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { 869 c.T.Fatalf("expected c.EP.LastError() == ErrConnectionRefused, got: %+v", err) 870 } 871 } 872 }) 873 } 874 } 875 876 // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket 877 // that is bound to a V4 multicast address. 878 func TestWriteOnBoundToV4Multicast(t *testing.T) { 879 for _, writeOpSequence := range writeOpSequences { 880 for _, flow := range []context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast} { 881 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 882 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 883 defer c.Cleanup() 884 885 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 886 887 // Bind to V4 mcast address. 888 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastAddr, Port: context.StackPort}); err != nil { 889 c.T.Fatal("Bind failed:", err) 890 } 891 892 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 893 }) 894 } 895 } 896 } 897 898 // TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a 899 // socket that is bound to a V4-mapped multicast address. 900 func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { 901 for _, writeOpSequence := range writeOpSequences { 902 for _, flow := range []context.TestFlow{context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6} { 903 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 904 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 905 defer c.Cleanup() 906 907 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 908 909 // Bind to V4Mapped mcast address. 910 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV4MappedAddr, Port: context.StackPort}); err != nil { 911 c.T.Fatalf("Bind failed: %s", err) 912 } 913 914 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 915 }) 916 } 917 } 918 } 919 920 // TestWriteOnBoundToV6Multicast checks that we can send packets out of a 921 // socket that is bound to a V6 multicast address. 922 func TestWriteOnBoundToV6Multicast(t *testing.T) { 923 for _, writeOpSequence := range writeOpSequences { 924 for _, flow := range []context.TestFlow{context.UnicastV6, context.MulticastV6} { 925 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 926 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 927 defer c.Cleanup() 928 929 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 930 931 // Bind to V6 mcast address. 932 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV6Addr, Port: context.StackPort}); err != nil { 933 c.T.Fatalf("Bind failed: %s", err) 934 } 935 936 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 937 }) 938 } 939 } 940 } 941 942 // TestWriteOnBoundToV6Multicast checks that we can send packets out of a 943 // V6-only socket that is bound to a V6 multicast address. 944 func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { 945 for _, writeOpSequence := range writeOpSequences { 946 for _, flow := range []context.TestFlow{context.UnicastV6Only, context.MulticastV6Only} { 947 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 948 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 949 defer c.Cleanup() 950 951 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 952 953 // Bind to V6 mcast address. 954 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.MulticastV6Addr, Port: context.StackPort}); err != nil { 955 c.T.Fatalf("Bind failed: %s", err) 956 } 957 958 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 959 }) 960 } 961 } 962 } 963 964 // TestWriteOnBoundToBroadcast checks that we can send packets out of a 965 // socket that is bound to the broadcast address. 966 func TestWriteOnBoundToBroadcast(t *testing.T) { 967 for _, writeOpSequence := range writeOpSequences { 968 for _, flow := range []context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast} { 969 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 970 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 971 defer c.Cleanup() 972 973 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 974 975 // Bind to V4 broadcast address. 976 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.BroadcastAddr, Port: context.StackPort}); err != nil { 977 c.T.Fatal("Bind failed:", err) 978 } 979 980 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 981 }) 982 } 983 } 984 } 985 986 // TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a 987 // socket that is bound to the V4-mapped broadcast address. 988 func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { 989 for _, writeOpSequence := range writeOpSequences { 990 for _, flow := range []context.TestFlow{context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6} { 991 t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { 992 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 993 defer c.Cleanup() 994 995 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 996 997 // Bind to V4Mapped mcast address. 998 if err := c.EP.Bind(tcpip.FullAddress{Addr: context.BroadcastV4MappedAddr, Port: context.StackPort}); err != nil { 999 c.T.Fatalf("Bind failed: %s", err) 1000 } 1001 1002 testWriteOpSequenceSucceeds(c, flow, writeOpSequence) 1003 }) 1004 } 1005 } 1006 } 1007 1008 func TestReadIncrementsPacketsReceived(t *testing.T) { 1009 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1010 defer c.Cleanup() 1011 1012 // Create IPv4 UDP endpoint 1013 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1014 1015 // Bind to wildcard. 1016 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1017 c.T.Fatalf("Bind failed: %s", err) 1018 } 1019 1020 testRead(c, context.UnicastV4) 1021 1022 var want uint64 = 1 1023 if got := c.Stack.Stats().UDP.PacketsReceived.Value(); got != want { 1024 c.T.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) 1025 } 1026 } 1027 1028 func TestReadRecvOriginalDstAddr(t *testing.T) { 1029 tests := []struct { 1030 name string 1031 proto tcpip.NetworkProtocolNumber 1032 flow context.TestFlow 1033 expectedOriginalDstAddr tcpip.FullAddress 1034 }{ 1035 { 1036 name: "IPv4 unicast", 1037 proto: header.IPv4ProtocolNumber, 1038 flow: context.UnicastV4, 1039 expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.StackAddr, Port: context.StackPort}, 1040 }, 1041 { 1042 name: "IPv4 multicast", 1043 proto: header.IPv4ProtocolNumber, 1044 flow: context.MulticastV4, 1045 // This should actually be a unicast address assigned to the interface. 1046 // 1047 // TODO(gvisor.dev/issue/3556): This check is validating incorrect 1048 // behaviour. We still include the test so that once the bug is resolved, 1049 // this test will start to fail and the individual tasked with fixing this 1050 // bug knows to also fix this test :). 1051 expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.MulticastAddr, Port: context.StackPort}, 1052 }, 1053 { 1054 name: "IPv4 broadcast", 1055 proto: header.IPv4ProtocolNumber, 1056 flow: context.Broadcast, 1057 // This should actually be a unicast address assigned to the interface. 1058 // 1059 // TODO(gvisor.dev/issue/3556): This check is validating incorrect 1060 // behaviour. We still include the test so that once the bug is resolved, 1061 // this test will start to fail and the individual tasked with fixing this 1062 // bug knows to also fix this test :). 1063 expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.BroadcastAddr, Port: context.StackPort}, 1064 }, 1065 { 1066 name: "IPv6 unicast", 1067 proto: header.IPv6ProtocolNumber, 1068 flow: context.UnicastV6, 1069 expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.StackV6Addr, Port: context.StackPort}, 1070 }, 1071 { 1072 name: "IPv6 multicast", 1073 proto: header.IPv6ProtocolNumber, 1074 flow: context.MulticastV6, 1075 // This should actually be a unicast address assigned to the interface. 1076 // 1077 // TODO(gvisor.dev/issue/3556): This check is validating incorrect 1078 // behaviour. We still include the test so that once the bug is resolved, 1079 // this test will start to fail and the individual tasked with fixing this 1080 // bug knows to also fix this test :). 1081 expectedOriginalDstAddr: tcpip.FullAddress{NIC: context.NICID, Addr: context.MulticastV6Addr, Port: context.StackPort}, 1082 }, 1083 } 1084 1085 for _, test := range tests { 1086 t.Run(test.name, func(t *testing.T) { 1087 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1088 defer c.Cleanup() 1089 1090 c.CreateEndpoint(test.proto, udp.ProtocolNumber) 1091 1092 bindAddr := tcpip.FullAddress{Port: context.StackPort} 1093 if err := c.EP.Bind(bindAddr); err != nil { 1094 t.Fatalf("Bind(%#v): %s", bindAddr, err) 1095 } 1096 1097 if test.flow.IsMulticast() { 1098 ifoptSet := tcpip.AddMembershipOption{NIC: context.NICID, MulticastAddr: test.flow.GetMulticastAddr()} 1099 if err := c.EP.SetSockOpt(&ifoptSet); err != nil { 1100 c.T.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) 1101 } 1102 } 1103 1104 c.EP.SocketOptions().SetReceiveOriginalDstAddress(true) 1105 1106 testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) 1107 1108 if got := c.Stack.Stats().UDP.PacketsReceived.Value(); got != 1 { 1109 t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) 1110 } 1111 }) 1112 } 1113 } 1114 1115 func TestWriteIncrementsPacketsSent(t *testing.T) { 1116 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1117 defer c.Cleanup() 1118 1119 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1120 1121 testDualWrite(c) 1122 1123 var want uint64 = 2 1124 if got := c.Stack.Stats().UDP.PacketsSent.Value(); got != want { 1125 c.T.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) 1126 } 1127 } 1128 1129 func TestNoChecksum(t *testing.T) { 1130 for _, writeOpSequence := range writeOpSequences { 1131 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6} { 1132 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1133 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1134 defer c.Cleanup() 1135 1136 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1137 1138 // Disable the checksum generation. 1139 c.EP.SocketOptions().SetNoChecksum(true) 1140 // This option is effective on IPv4 only. 1141 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.UDP(checker.NoChecksum(flow.IsV4()))) 1142 1143 // Enable the checksum generation. 1144 c.EP.SocketOptions().SetNoChecksum(false) 1145 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.UDP(checker.NoChecksum(false))) 1146 }) 1147 } 1148 } 1149 } 1150 1151 var _ stack.NetworkInterface = (*testInterface)(nil) 1152 1153 type testInterface struct { 1154 stack.NetworkInterface 1155 } 1156 1157 func (*testInterface) ID() tcpip.NICID { 1158 return 0 1159 } 1160 1161 func (*testInterface) Enabled() bool { 1162 return true 1163 } 1164 1165 func TestDefaultTTL(t *testing.T) { 1166 for _, writeOpSequence := range writeOpSequences { 1167 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV4in6, context.UnicastV6, context.UnicastV6Only, context.Broadcast, context.BroadcastIn6} { 1168 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1169 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1170 defer c.Cleanup() 1171 1172 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1173 proto := c.Stack.NetworkProtocolInstance(flow.NetProto()) 1174 if proto == nil { 1175 t.Fatalf("c.Stack.NetworkProtocolInstance(flow.NetProto()) did not return a protocol") 1176 } 1177 1178 var initialDefaultTTL tcpip.DefaultTTLOption 1179 if err := proto.Option(&initialDefaultTTL); err != nil { 1180 t.Fatalf("proto.Option(&initialDefaultTTL) (%T) failed: %s", initialDefaultTTL, err) 1181 } 1182 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(uint8(initialDefaultTTL))) 1183 1184 newDefaultTTL := tcpip.DefaultTTLOption(initialDefaultTTL + 1) 1185 if err := proto.SetOption(&newDefaultTTL); err != nil { 1186 c.T.Fatalf("proto.SetOption(&%T(%d))) failed: %s", newDefaultTTL, newDefaultTTL, err) 1187 } 1188 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(uint8(newDefaultTTL))) 1189 }) 1190 } 1191 } 1192 } 1193 1194 func TestSetNonMulticastTTL(t *testing.T) { 1195 for _, writeOpSequence := range writeOpSequences { 1196 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV4in6, context.UnicastV6, context.UnicastV6Only, context.Broadcast, context.BroadcastIn6} { 1197 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1198 for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { 1199 t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { 1200 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1201 defer c.Cleanup() 1202 1203 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1204 1205 var relevantOpt tcpip.SockOptInt 1206 var irrelevantOpt tcpip.SockOptInt 1207 if flow.IsV4() { 1208 relevantOpt = tcpip.IPv4TTLOption 1209 irrelevantOpt = tcpip.IPv6HopLimitOption 1210 } else { 1211 relevantOpt = tcpip.IPv6HopLimitOption 1212 irrelevantOpt = tcpip.IPv4TTLOption 1213 } 1214 if err := c.EP.SetSockOptInt(relevantOpt, int(wantTTL)); err != nil { 1215 c.T.Fatalf("SetSockOptInt(%d, %d) failed: %s", relevantOpt, wantTTL, err) 1216 } 1217 // Set a different ttl/hoplimit for the unused protocol, showing that 1218 // it does not affect the other protocol. 1219 if err := c.EP.SetSockOptInt(irrelevantOpt, int(wantTTL+1)); err != nil { 1220 c.T.Fatalf("SetSockOptInt(%d, %d) failed: %s", irrelevantOpt, wantTTL, err) 1221 } 1222 1223 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(wantTTL)) 1224 }) 1225 } 1226 }) 1227 } 1228 } 1229 } 1230 1231 func TestSetMulticastTTL(t *testing.T) { 1232 for _, writeOpSequence := range writeOpSequences { 1233 for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV4in6, context.MulticastV6} { 1234 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1235 for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { 1236 t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { 1237 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1238 defer c.Cleanup() 1239 1240 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1241 1242 if err := c.EP.SetSockOptInt(tcpip.MulticastTTLOption, int(wantTTL)); err != nil { 1243 c.T.Fatalf("SetSockOptInt failed: %s", err) 1244 } 1245 1246 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TTL(wantTTL)) 1247 }) 1248 } 1249 }) 1250 } 1251 } 1252 } 1253 1254 var v4PacketFlows = [...]context.TestFlow{context.UnicastV4, context.MulticastV4, context.Broadcast, context.UnicastV4in6, context.MulticastV4in6, context.BroadcastIn6} 1255 1256 func TestSetTOS(t *testing.T) { 1257 for _, writeOpSequence := range writeOpSequences { 1258 for _, flow := range v4PacketFlows { 1259 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1260 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1261 defer c.Cleanup() 1262 1263 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1264 1265 const tos = testTOS 1266 v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) 1267 if err != nil { 1268 c.T.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) 1269 } 1270 // Test for expected default value. 1271 if v != 0 { 1272 c.T.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) 1273 } 1274 1275 if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { 1276 c.T.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) 1277 } 1278 1279 v, err = c.EP.GetSockOptInt(tcpip.IPv4TOSOption) 1280 if err != nil { 1281 c.T.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) 1282 } 1283 1284 if v != tos { 1285 c.T.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) 1286 } 1287 1288 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TOS(tos, 0)) 1289 }) 1290 } 1291 } 1292 } 1293 1294 var v6PacketFlows = [...]context.TestFlow{context.UnicastV6, context.UnicastV6Only, context.MulticastV6} 1295 1296 func TestSetTClass(t *testing.T) { 1297 for _, writeOpSequence := range writeOpSequences { 1298 for _, flow := range v6PacketFlows { 1299 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1300 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1301 defer c.Cleanup() 1302 1303 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1304 1305 const tClass = testTOS 1306 v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) 1307 if err != nil { 1308 c.T.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) 1309 } 1310 // Test for expected default value. 1311 if v != 0 { 1312 c.T.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) 1313 } 1314 1315 if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { 1316 c.T.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) 1317 } 1318 1319 v, err = c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) 1320 if err != nil { 1321 c.T.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) 1322 } 1323 1324 if v != tClass { 1325 c.T.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) 1326 } 1327 1328 // The header getter for TClass is called TOS, so use that checker. 1329 testWriteOpSequenceSucceeds(c, flow, writeOpSequence, checker.TOS(tClass, 0)) 1330 }) 1331 } 1332 } 1333 } 1334 1335 func TestReceiveControlMessage(t *testing.T) { 1336 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6, context.UnicastV6Only, context.MulticastV4, context.MulticastV6, context.MulticastV6Only, context.Broadcast} { 1337 t.Run(flow.String(), func(t *testing.T) { 1338 for _, test := range []struct { 1339 name string 1340 optionProtocol tcpip.NetworkProtocolNumber 1341 getReceiveOption func(tcpip.Endpoint) bool 1342 setReceiveOption func(tcpip.Endpoint, bool) 1343 presenceChecker checker.ControlMessagesChecker 1344 absenceChecker checker.ControlMessagesChecker 1345 }{ 1346 { 1347 name: "TOS", 1348 optionProtocol: header.IPv4ProtocolNumber, 1349 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTOS() }, 1350 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTOS(value) }, 1351 presenceChecker: checker.ReceiveTOS(testTOS), 1352 absenceChecker: checker.NoTOSReceived(), 1353 }, 1354 { 1355 name: "TClass", 1356 optionProtocol: header.IPv6ProtocolNumber, 1357 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTClass() }, 1358 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTClass(value) }, 1359 presenceChecker: checker.ReceiveTClass(testTOS), 1360 absenceChecker: checker.NoTClassReceived(), 1361 }, 1362 { 1363 name: "TTL", 1364 optionProtocol: header.IPv4ProtocolNumber, 1365 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTTL() }, 1366 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTTL(value) }, 1367 presenceChecker: checker.ReceiveTTL(testTTL), 1368 absenceChecker: checker.NoTTLReceived(), 1369 }, 1370 { 1371 name: "HopLimit", 1372 optionProtocol: header.IPv6ProtocolNumber, 1373 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveHopLimit() }, 1374 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveHopLimit(value) }, 1375 presenceChecker: checker.ReceiveHopLimit(testTTL), 1376 absenceChecker: checker.NoHopLimitReceived(), 1377 }, 1378 { 1379 name: "PacketInfo", 1380 optionProtocol: header.IPv4ProtocolNumber, 1381 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceivePacketInfo() }, 1382 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceivePacketInfo(value) }, 1383 presenceChecker: func() checker.ControlMessagesChecker { 1384 h := flow.MakeHeader4Tuple(context.Incoming) 1385 return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ 1386 NIC: context.NICID, 1387 // TODO(https://gvisor.dev/issue/3556): Expect the NIC's address 1388 // instead of the header destination address for the LocalAddr 1389 // field. 1390 LocalAddr: h.Dst.Addr, 1391 DestinationAddr: h.Dst.Addr, 1392 }) 1393 }(), 1394 absenceChecker: checker.NoIPPacketInfoReceived(), 1395 }, 1396 { 1397 name: "IPv6PacketInfo", 1398 optionProtocol: header.IPv6ProtocolNumber, 1399 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetIPv6ReceivePacketInfo() }, 1400 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetIPv6ReceivePacketInfo(value) }, 1401 presenceChecker: func() checker.ControlMessagesChecker { 1402 h := flow.MakeHeader4Tuple(context.Incoming) 1403 return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ 1404 NIC: context.NICID, 1405 Addr: h.Dst.Addr, 1406 }) 1407 }(), 1408 absenceChecker: checker.NoIPv6PacketInfoReceived(), 1409 }, 1410 } { 1411 t.Run(test.name, func(t *testing.T) { 1412 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol}) 1413 defer c.Cleanup() 1414 1415 c.CreateEndpointForFlow(flow, udp.ProtocolNumber) 1416 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1417 c.T.Fatalf("Bind failed: %s", err) 1418 } 1419 if flow.IsMulticast() { 1420 netProto := flow.NetProto() 1421 addr := flow.GetMulticastAddr() 1422 if err := c.Stack.JoinGroup(netProto, context.NICID, addr); err != nil { 1423 c.T.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, context.NICID, addr, err) 1424 } 1425 } 1426 1427 payload := newRandomPayload(arbitraryPayloadSize) 1428 buf := context.BuildUDPPacket(payload, flow, context.Incoming, testTOS, testTTL, false) 1429 1430 if test.getReceiveOption(c.EP) { 1431 t.Fatal("got getReceiveOption() = true, want = false") 1432 } 1433 1434 test.setReceiveOption(c.EP, true) 1435 if !test.getReceiveOption(c.EP) { 1436 t.Fatal("got getReceiveOption() = false, want = true") 1437 } 1438 1439 c.InjectPacket(flow.NetProto(), buf) 1440 if flow.NetProto() == test.optionProtocol { 1441 c.ReadFromEndpointExpectSuccess(payload, flow, test.presenceChecker) 1442 } else { 1443 c.ReadFromEndpointExpectSuccess(payload, flow, test.absenceChecker) 1444 } 1445 }) 1446 } 1447 }) 1448 } 1449 } 1450 1451 func TestMulticastInterfaceOption(t *testing.T) { 1452 for _, flow := range []context.TestFlow{context.MulticastV4, context.MulticastV4in6, context.MulticastV6, context.MulticastV6Only} { 1453 t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { 1454 for _, bindTyp := range []string{"bound", "unbound"} { 1455 t.Run(bindTyp, func(t *testing.T) { 1456 for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { 1457 t.Run(optTyp, func(t *testing.T) { 1458 h := flow.MakeHeader4Tuple(context.Outgoing) 1459 mcastAddr := h.Dst.Addr 1460 localIfAddr := h.Src.Addr 1461 1462 var ifoptSet tcpip.MulticastInterfaceOption 1463 switch optTyp { 1464 case "use local-addr": 1465 ifoptSet.InterfaceAddr = localIfAddr 1466 case "use NICID": 1467 ifoptSet.NIC = 1 1468 case "use local-addr and NIC": 1469 ifoptSet.InterfaceAddr = localIfAddr 1470 ifoptSet.NIC = 1 1471 default: 1472 t.Fatal("unknown test variant") 1473 } 1474 1475 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1476 defer c.Cleanup() 1477 1478 c.CreateEndpoint(flow.SockProto(), udp.ProtocolNumber) 1479 1480 if bindTyp == "bound" { 1481 // Bind the socket by connecting to the multicast address. 1482 // This may have an influence on how the multicast interface 1483 // is set. 1484 addr := tcpip.FullAddress{ 1485 Addr: flow.MapAddrIfApplicable(mcastAddr), 1486 Port: context.StackPort, 1487 } 1488 if err := c.EP.Connect(addr); err != nil { 1489 c.T.Fatalf("Connect failed: %s", err) 1490 } 1491 } 1492 1493 if err := c.EP.SetSockOpt(&ifoptSet); err != nil { 1494 c.T.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) 1495 } 1496 1497 // Verify multicast interface addr and NIC were set correctly. 1498 // Note that NIC must be 1 since this is our outgoing interface. 1499 var ifoptGot tcpip.MulticastInterfaceOption 1500 if err := c.EP.GetSockOpt(&ifoptGot); err != nil { 1501 c.T.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) 1502 } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { 1503 c.T.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) 1504 } 1505 }) 1506 } 1507 }) 1508 } 1509 }) 1510 } 1511 } 1512 1513 // TestV4UnknownDestination verifies that we generate an ICMPv4 Destination 1514 // Unreachable message when a udp datagram is received on ports for which there 1515 // is no bound udp socket. 1516 func TestV4UnknownDestination(t *testing.T) { 1517 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1518 defer c.Cleanup() 1519 1520 testCases := []struct { 1521 flow context.TestFlow 1522 icmpRequired bool 1523 // largePayload if true, will result in a payload large enough 1524 // so that the final generated IPv4 packet is larger than 1525 // header.IPv4MinimumProcessableDatagramSize. 1526 largePayload bool 1527 // badChecksum if true, will set an invalid checksum in the 1528 // header. 1529 badChecksum bool 1530 }{ 1531 {context.UnicastV4, true, false, false}, 1532 {context.UnicastV4, true, true, false}, 1533 {context.UnicastV4, false, false, true}, 1534 {context.UnicastV4, false, true, true}, 1535 {context.MulticastV4, false, false, false}, 1536 {context.MulticastV4, false, true, false}, 1537 {context.Broadcast, false, false, false}, 1538 {context.Broadcast, false, true, false}, 1539 } 1540 checksumErrors := uint64(0) 1541 for _, tc := range testCases { 1542 t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { 1543 payloadSize := arbitraryPayloadSize 1544 if tc.largePayload { 1545 payloadSize += header.IPv4MinimumProcessableDatagramSize 1546 } 1547 payload := newRandomPayload(payloadSize) 1548 c.InjectPacket(tc.flow.NetProto(), context.BuildUDPPacket(payload, tc.flow, context.Incoming, testTOS, testTTL, tc.badChecksum)) 1549 if tc.badChecksum { 1550 checksumErrors++ 1551 if got, want := c.Stack.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { 1552 t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1553 } 1554 } 1555 if !tc.icmpRequired { 1556 if p := c.LinkEP.Read(); p != nil { 1557 t.Fatalf("unexpected packet received: %+v", p) 1558 } 1559 return 1560 } 1561 1562 // ICMP required. 1563 p := c.LinkEP.Read() 1564 if p == nil { 1565 t.Fatalf("packet wasn't written out") 1566 } 1567 1568 buf := p.ToBuffer() 1569 defer buf.Release() 1570 p.DecRef() 1571 pkt := buf.Flatten() 1572 if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { 1573 t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) 1574 } 1575 1576 hdr := buffer.NewViewWithData(pkt) 1577 defer hdr.Release() 1578 checker.IPv4(t, hdr, checker.ICMPv4( 1579 checker.ICMPv4Type(header.ICMPv4DstUnreachable), 1580 checker.ICMPv4Code(header.ICMPv4PortUnreachable))) 1581 1582 // We need to compare the included data part of the UDP packet that is in 1583 // the ICMP packet with the matching original data. 1584 icmpPkt := header.ICMPv4(header.IPv4(hdr.AsSlice()).Payload()) 1585 payloadIPHeader := header.IPv4(icmpPkt.Payload()) 1586 incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize 1587 wantLen := len(payload) 1588 if tc.largePayload { 1589 // To work out the data size we need to simulate what the sender would 1590 // have done. The wanted size is the total available minus the sum of 1591 // the headers in the UDP AND ICMP packets, given that we know the test 1592 // had only a minimal IP header but the ICMP sender will have allowed 1593 // for a maximally sized packet header. 1594 wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength 1595 } 1596 1597 // In the case of large payloads the IP packet may be truncated. Update 1598 // the length field before retrieving the udp datagram payload. 1599 // Add back the two headers within the payload. 1600 payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) 1601 1602 origDgram := header.UDP(payloadIPHeader.Payload()) 1603 if got, want := len(origDgram.Payload()), wantLen; got != want { 1604 t.Fatalf("unexpected payload length got: %d, want: %d", got, want) 1605 } 1606 if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { 1607 t.Fatalf("unexpected payload got: %d, want: %d", got, want) 1608 } 1609 }) 1610 } 1611 } 1612 1613 // TestV6UnknownDestination verifies that we generate an ICMPv6 Destination 1614 // Unreachable message when a udp datagram is received on ports for which there 1615 // is no bound udp socket. 1616 func TestV6UnknownDestination(t *testing.T) { 1617 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1618 defer c.Cleanup() 1619 1620 testCases := []struct { 1621 flow context.TestFlow 1622 icmpRequired bool 1623 // largePayload if true will result in a payload large enough to 1624 // create an IPv6 packet > header.IPv6MinimumMTU bytes. 1625 largePayload bool 1626 // badChecksum if true, will set an invalid checksum in the 1627 // header. 1628 badChecksum bool 1629 }{ 1630 {context.UnicastV6, true, false, false}, 1631 {context.UnicastV6, true, true, false}, 1632 {context.UnicastV6, false, false, true}, 1633 {context.UnicastV6, false, true, true}, 1634 {context.MulticastV6, false, false, false}, 1635 {context.MulticastV6, false, true, false}, 1636 } 1637 checksumErrors := uint64(0) 1638 for _, tc := range testCases { 1639 t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { 1640 payloadSize := arbitraryPayloadSize 1641 if tc.largePayload { 1642 payloadSize += header.IPv6MinimumMTU 1643 } 1644 payload := newRandomPayload(payloadSize) 1645 c.InjectPacket(tc.flow.NetProto(), context.BuildUDPPacket(payload, tc.flow, context.Incoming, testTOS, testTTL, tc.badChecksum)) 1646 if tc.badChecksum { 1647 checksumErrors++ 1648 if got, want := c.Stack.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { 1649 t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1650 } 1651 } 1652 if !tc.icmpRequired { 1653 if p := c.LinkEP.Read(); p != nil { 1654 t.Fatalf("unexpected packet received: %+v", p) 1655 } 1656 return 1657 } 1658 1659 // ICMP required. 1660 p := c.LinkEP.Read() 1661 if p == nil { 1662 t.Fatalf("packet wasn't written out") 1663 } 1664 1665 buf := p.ToBuffer() 1666 defer buf.Release() 1667 p.DecRef() 1668 pkt := buf.Flatten() 1669 if got, want := len(pkt), header.IPv6MinimumMTU; got > want { 1670 t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) 1671 } 1672 1673 hdr := buffer.NewViewWithData(pkt) 1674 defer hdr.Release() 1675 checker.IPv6(t, hdr, checker.ICMPv6( 1676 checker.ICMPv6Type(header.ICMPv6DstUnreachable), 1677 checker.ICMPv6Code(header.ICMPv6PortUnreachable))) 1678 1679 icmpPkt := header.ICMPv6(header.IPv6(hdr.AsSlice()).Payload()) 1680 payloadIPHeader := header.IPv6(icmpPkt.Payload()) 1681 wantLen := len(payload) 1682 if tc.largePayload { 1683 wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize 1684 } 1685 // In case of large payloads the IP packet may be truncated. Update 1686 // the length field before retrieving the udp datagram payload. 1687 payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) 1688 1689 origDgram := header.UDP(payloadIPHeader.Payload()) 1690 if got, want := len(origDgram.Payload()), wantLen; got != want { 1691 t.Fatalf("unexpected payload length got: %d, want: %d", got, want) 1692 } 1693 if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { 1694 t.Fatalf("unexpected payload got: %v, want: %v", got, want) 1695 } 1696 }) 1697 } 1698 } 1699 1700 // TestIncrementMalformedPacketsReceived verifies if the malformed received 1701 // global and endpoint stats are incremented. 1702 func TestIncrementMalformedPacketsReceived(t *testing.T) { 1703 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1704 defer c.Cleanup() 1705 1706 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1707 // Bind to wildcard. 1708 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1709 c.T.Fatalf("Bind failed: %s", err) 1710 } 1711 1712 payload := newRandomPayload(arbitraryPayloadSize) 1713 h := context.UnicastV6.MakeHeader4Tuple(context.Incoming) 1714 buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false) 1715 1716 // Invalidate the UDP header length field. 1717 u := header.UDP(buf[header.IPv6MinimumSize:]) 1718 u.SetLength(u.Length() + 1) 1719 c.InjectPacket(header.IPv6ProtocolNumber, buf) 1720 1721 const want = 1 1722 if got := c.Stack.Stats().UDP.MalformedPacketsReceived.Value(); got != want { 1723 t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) 1724 } 1725 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { 1726 t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) 1727 } 1728 } 1729 1730 // TestShortHeader verifies that when a packet with a too-short UDP header is 1731 // received, the malformed received global stat gets incremented. 1732 func TestShortHeader(t *testing.T) { 1733 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1734 defer c.Cleanup() 1735 1736 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1737 // Bind to wildcard. 1738 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1739 c.T.Fatalf("Bind failed: %s", err) 1740 } 1741 1742 h := context.UnicastV6.MakeHeader4Tuple(context.Incoming) 1743 1744 // Allocate a buffer for an IPv6 and too-short UDP header. 1745 const udpSize = header.UDPMinimumSize - 1 1746 buf := make([]byte, header.IPv6MinimumSize+udpSize) 1747 // Initialize the IP header. 1748 ip := header.IPv6(buf) 1749 ip.Encode(&header.IPv6Fields{ 1750 TrafficClass: testTOS, 1751 PayloadLength: uint16(udpSize), 1752 TransportProtocol: udp.ProtocolNumber, 1753 HopLimit: testTTL, 1754 SrcAddr: h.Src.Addr, 1755 DstAddr: h.Dst.Addr, 1756 }) 1757 1758 // Initialize the UDP header. 1759 udpHdr := header.UDP(make([]byte, header.UDPMinimumSize)) 1760 udpHdr.Encode(&header.UDPFields{ 1761 SrcPort: h.Src.Port, 1762 DstPort: h.Dst.Port, 1763 Length: header.UDPMinimumSize, 1764 }) 1765 // Calculate the UDP pseudo-header checksum. 1766 xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.Src.Addr, h.Dst.Addr, uint16(len(udpHdr))) 1767 udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) 1768 // Copy all but the last byte of the UDP header into the packet. 1769 copy(buf[header.IPv6MinimumSize:], udpHdr) 1770 1771 // Inject packet. 1772 c.InjectPacket(header.IPv6ProtocolNumber, buf) 1773 1774 if got, want := c.Stack.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { 1775 t.Errorf("got c.Stack.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) 1776 } 1777 } 1778 1779 // TestBadChecksumErrors verifies if a checksum error is detected, 1780 // global and endpoint stats are incremented. 1781 func TestBadChecksumErrors(t *testing.T) { 1782 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6} { 1783 t.Run(flow.String(), func(t *testing.T) { 1784 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1785 defer c.Cleanup() 1786 1787 c.CreateEndpoint(flow.SockProto(), udp.ProtocolNumber) 1788 // Bind to wildcard. 1789 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1790 c.T.Fatalf("Bind failed: %s", err) 1791 } 1792 1793 c.InjectPacket(flow.NetProto(), context.BuildUDPPacket(newRandomPayload(arbitraryPayloadSize), flow, context.Incoming, testTOS, testTTL, true)) 1794 1795 const want = 1 1796 if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want { 1797 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1798 } 1799 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 1800 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 1801 } 1802 }) 1803 } 1804 } 1805 1806 // TestPayloadModifiedV4 verifies if a checksum error is detected, 1807 // global and endpoint stats are incremented. 1808 func TestPayloadModifiedV4(t *testing.T) { 1809 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1810 defer c.Cleanup() 1811 1812 c.CreateEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber) 1813 // Bind to wildcard. 1814 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1815 c.T.Fatalf("Bind failed: %s", err) 1816 } 1817 1818 payload := newRandomPayload(arbitraryPayloadSize) 1819 h := context.UnicastV4.MakeHeader4Tuple(context.Incoming) 1820 buf := context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false) 1821 // Modify the payload so that the checksum value in the UDP header will be 1822 // incorrect. 1823 buf[len(buf)-1]++ 1824 c.InjectPacket(header.IPv4ProtocolNumber, buf) 1825 1826 const want = 1 1827 if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want { 1828 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1829 } 1830 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 1831 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 1832 } 1833 } 1834 1835 // TestPayloadModifiedV6 verifies if a checksum error is detected, 1836 // global and endpoint stats are incremented. 1837 func TestPayloadModifiedV6(t *testing.T) { 1838 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1839 defer c.Cleanup() 1840 1841 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1842 // Bind to wildcard. 1843 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1844 c.T.Fatalf("Bind failed: %s", err) 1845 } 1846 1847 payload := newRandomPayload(arbitraryPayloadSize) 1848 h := context.UnicastV6.MakeHeader4Tuple(context.Incoming) 1849 buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false) 1850 // Modify the payload so that the checksum value in the UDP header will be 1851 // incorrect. 1852 buf[len(buf)-1]++ 1853 c.InjectPacket(header.IPv6ProtocolNumber, buf) 1854 1855 const want = 1 1856 if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want { 1857 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1858 } 1859 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 1860 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 1861 } 1862 } 1863 1864 // TestChecksumZeroV4 verifies if the checksum value is zero, global and 1865 // endpoint states are *not* incremented (UDP checksum is optional on IPv4). 1866 func TestChecksumZeroV4(t *testing.T) { 1867 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1868 defer c.Cleanup() 1869 1870 c.CreateEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber) 1871 // Bind to wildcard. 1872 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1873 c.T.Fatalf("Bind failed: %s", err) 1874 } 1875 1876 payload := newRandomPayload(arbitraryPayloadSize) 1877 h := context.UnicastV4.MakeHeader4Tuple(context.Incoming) 1878 buf := context.BuildV4UDPPacket(payload, h, testTOS, testTTL, false) 1879 // Set the checksum field in the UDP header to zero. 1880 u := header.UDP(buf[header.IPv4MinimumSize:]) 1881 u.SetChecksum(0) 1882 c.InjectPacket(header.IPv4ProtocolNumber, buf) 1883 1884 const want = 0 1885 if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want { 1886 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1887 } 1888 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 1889 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 1890 } 1891 } 1892 1893 // TestChecksumZeroV6 verifies if the checksum value is zero, global and 1894 // endpoint states are incremented (UDP checksum is *not* optional on IPv6). 1895 func TestChecksumZeroV6(t *testing.T) { 1896 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1897 defer c.Cleanup() 1898 1899 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1900 // Bind to wildcard. 1901 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1902 c.T.Fatalf("Bind failed: %s", err) 1903 } 1904 1905 payload := newRandomPayload(arbitraryPayloadSize) 1906 h := context.UnicastV6.MakeHeader4Tuple(context.Incoming) 1907 buf := context.BuildV6UDPPacket(payload, h, testTOS, testTTL, false) 1908 // Set the checksum field in the UDP header to zero. 1909 u := header.UDP(buf[header.IPv6MinimumSize:]) 1910 u.SetChecksum(0) 1911 c.InjectPacket(header.IPv6ProtocolNumber, buf) 1912 1913 const want = 1 1914 if got := c.Stack.Stats().UDP.ChecksumErrors.Value(); got != want { 1915 t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) 1916 } 1917 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { 1918 t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) 1919 } 1920 } 1921 1922 // TestShutdownRead verifies endpoint read shutdown and error 1923 // stats increment on packet receive. 1924 func TestShutdownRead(t *testing.T) { 1925 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1926 defer c.Cleanup() 1927 1928 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1929 1930 // Bind to wildcard. 1931 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 1932 c.T.Fatalf("Bind failed: %s", err) 1933 } 1934 1935 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 1936 c.T.Fatalf("Connect failed: %s", err) 1937 } 1938 1939 if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { 1940 t.Fatalf("Shutdown failed: %s", err) 1941 } 1942 1943 testFailingRead(c, context.UnicastV6, true /* expectReadError */) 1944 1945 var want uint64 = 1 1946 if got := c.Stack.Stats().UDP.ReceiveBufferErrors.Value(); got != want { 1947 t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) 1948 } 1949 if got := c.EP.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { 1950 t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) 1951 } 1952 } 1953 1954 // TestShutdownWrite verifies endpoint write shutdown and error 1955 // stats increment on packet write. 1956 func TestShutdownWrite(t *testing.T) { 1957 for _, writeOpSequence := range writeOpSequences { 1958 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 1959 1960 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 1961 1962 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 1963 c.T.Fatalf("Connect failed: %s", err) 1964 } 1965 1966 if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { 1967 t.Fatalf("Shutdown failed: %s", err) 1968 } 1969 1970 testWriteOpSequenceFails(c, context.UnicastV6, writeOpSequence, &tcpip.ErrClosedForSend{}) 1971 c.Cleanup() 1972 } 1973 } 1974 1975 func TestOutgoingSubnetBroadcast(t *testing.T) { 1976 const nicID1 = 1 1977 1978 ipv4Addr := tcpip.AddressWithPrefix{ 1979 Address: tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")), 1980 PrefixLen: 24, 1981 } 1982 ipv4Subnet := ipv4Addr.Subnet() 1983 ipv4SubnetBcast := ipv4Subnet.Broadcast() 1984 ipv4Gateway := testutil.MustParse4("192.168.1.1") 1985 ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ 1986 Address: tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")), 1987 PrefixLen: 31, 1988 } 1989 ipv4Subnet31 := ipv4AddrPrefix31.Subnet() 1990 ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() 1991 ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ 1992 Address: tcpip.AddrFromSlice([]byte("\xc0\xa8\x01\x3a")), 1993 PrefixLen: 32, 1994 } 1995 ipv4Subnet32 := ipv4AddrPrefix32.Subnet() 1996 ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() 1997 ipv6Addr := tcpip.AddressWithPrefix{ 1998 Address: tcpip.AddrFromSlice([]byte("\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")), 1999 PrefixLen: 64, 2000 } 2001 ipv6Subnet := ipv6Addr.Subnet() 2002 ipv6SubnetBcast := ipv6Subnet.Broadcast() 2003 remNetAddr := tcpip.AddressWithPrefix{ 2004 Address: tcpip.AddrFromSlice([]byte("\x64\x0a\x7b\x18")), 2005 PrefixLen: 24, 2006 } 2007 remNetSubnet := remNetAddr.Subnet() 2008 remNetSubnetBcast := remNetSubnet.Broadcast() 2009 2010 tests := []struct { 2011 name string 2012 nicAddr tcpip.ProtocolAddress 2013 routes []tcpip.Route 2014 remoteAddr tcpip.Address 2015 requiresBroadcastOpt bool 2016 }{ 2017 { 2018 name: "IPv4 Broadcast to local subnet", 2019 nicAddr: tcpip.ProtocolAddress{ 2020 Protocol: header.IPv4ProtocolNumber, 2021 AddressWithPrefix: ipv4Addr, 2022 }, 2023 routes: []tcpip.Route{ 2024 { 2025 Destination: ipv4Subnet, 2026 NIC: nicID1, 2027 }, 2028 }, 2029 remoteAddr: ipv4SubnetBcast, 2030 requiresBroadcastOpt: true, 2031 }, 2032 { 2033 name: "IPv4 Broadcast to local /31 subnet", 2034 nicAddr: tcpip.ProtocolAddress{ 2035 Protocol: header.IPv4ProtocolNumber, 2036 AddressWithPrefix: ipv4AddrPrefix31, 2037 }, 2038 routes: []tcpip.Route{ 2039 { 2040 Destination: ipv4Subnet31, 2041 NIC: nicID1, 2042 }, 2043 }, 2044 remoteAddr: ipv4Subnet31Bcast, 2045 requiresBroadcastOpt: false, 2046 }, 2047 { 2048 name: "IPv4 Broadcast to local /32 subnet", 2049 nicAddr: tcpip.ProtocolAddress{ 2050 Protocol: header.IPv4ProtocolNumber, 2051 AddressWithPrefix: ipv4AddrPrefix32, 2052 }, 2053 routes: []tcpip.Route{ 2054 { 2055 Destination: ipv4Subnet32, 2056 NIC: nicID1, 2057 }, 2058 }, 2059 remoteAddr: ipv4Subnet32Bcast, 2060 requiresBroadcastOpt: false, 2061 }, 2062 // IPv6 has no notion of a broadcast. 2063 { 2064 name: "IPv6 'Broadcast' to local subnet", 2065 nicAddr: tcpip.ProtocolAddress{ 2066 Protocol: header.IPv6ProtocolNumber, 2067 AddressWithPrefix: ipv6Addr, 2068 }, 2069 routes: []tcpip.Route{ 2070 { 2071 Destination: ipv6Subnet, 2072 NIC: nicID1, 2073 }, 2074 }, 2075 remoteAddr: ipv6SubnetBcast, 2076 requiresBroadcastOpt: false, 2077 }, 2078 { 2079 name: "IPv4 Broadcast to remote subnet", 2080 nicAddr: tcpip.ProtocolAddress{ 2081 Protocol: header.IPv4ProtocolNumber, 2082 AddressWithPrefix: ipv4Addr, 2083 }, 2084 routes: []tcpip.Route{ 2085 { 2086 Destination: remNetSubnet, 2087 Gateway: ipv4Gateway, 2088 NIC: nicID1, 2089 }, 2090 }, 2091 remoteAddr: remNetSubnetBcast, 2092 // TODO(gvisor.dev/issue/3938): Once we support marking a route as 2093 // broadcast, this test should require the broadcast option to be set. 2094 requiresBroadcastOpt: false, 2095 }, 2096 } 2097 2098 for _, test := range tests { 2099 t.Run(test.name, func(t *testing.T) { 2100 s := stack.New(stack.Options{ 2101 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 2102 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 2103 Clock: &faketime.NullClock{}, 2104 }) 2105 defer s.Destroy() 2106 e := channel.New(0, context.DefaultMTU, "") 2107 defer e.Close() 2108 if err := s.CreateNIC(nicID1, e); err != nil { 2109 t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) 2110 } 2111 if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { 2112 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) 2113 } 2114 2115 s.SetRouteTable(test.routes) 2116 2117 var netProto tcpip.NetworkProtocolNumber 2118 switch l := test.remoteAddr.Len(); l { 2119 case header.IPv4AddressSize: 2120 netProto = header.IPv4ProtocolNumber 2121 case header.IPv6AddressSize: 2122 netProto = header.IPv6ProtocolNumber 2123 default: 2124 t.Fatalf("got unexpected address length = %d bytes", l) 2125 } 2126 2127 wq := waiter.Queue{} 2128 ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) 2129 if err != nil { 2130 t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) 2131 } 2132 defer ep.Close() 2133 2134 var r bytes.Reader 2135 data := []byte{1, 2, 3, 4} 2136 to := tcpip.FullAddress{ 2137 Addr: test.remoteAddr, 2138 Port: 80, 2139 } 2140 opts := tcpip.WriteOptions{To: &to} 2141 expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { 2142 if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { 2143 return nil 2144 } 2145 return &tcpip.ErrBroadcastDisabled{} 2146 } 2147 if !test.requiresBroadcastOpt { 2148 expectedErrWithoutBcastOpt = nil 2149 } 2150 2151 r.Reset(data) 2152 { 2153 n, err := ep.Write(&r, opts) 2154 if expectedErrWithoutBcastOpt != nil { 2155 if want := expectedErrWithoutBcastOpt(err); want != nil { 2156 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) 2157 } 2158 } else if err != nil { 2159 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2160 } 2161 } 2162 2163 ep.SocketOptions().SetBroadcast(true) 2164 2165 r.Reset(data) 2166 if n, err := ep.Write(&r, opts); err != nil { 2167 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2168 } 2169 2170 ep.SocketOptions().SetBroadcast(false) 2171 2172 r.Reset(data) 2173 { 2174 n, err := ep.Write(&r, opts) 2175 if expectedErrWithoutBcastOpt != nil { 2176 if want := expectedErrWithoutBcastOpt(err); want != nil { 2177 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) 2178 } 2179 } else if err != nil { 2180 t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) 2181 } 2182 } 2183 }) 2184 } 2185 } 2186 2187 func TestChecksumWithZeroValueOnesComplementSum(t *testing.T) { 2188 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol}) 2189 defer c.Cleanup() 2190 2191 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 2192 var writeOpts tcpip.WriteOptions 2193 h := context.UnicastV6.MakeHeader4Tuple(context.Outgoing) 2194 writeDstAddr := context.UnicastV6.MapAddrIfApplicable(h.Dst.Addr) 2195 writeOpts = tcpip.WriteOptions{ 2196 To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port}, 2197 } 2198 2199 // Write a packet to calculate what the checksum value will be with a zero 2200 // value payload. We will then take that checksum value to construct another 2201 // packet which would result in the ones complement of the packet to be zero. 2202 var payload [2]byte 2203 { 2204 var r bytes.Reader 2205 r.Reset(payload[:]) 2206 n, err := c.EP.Write(&r, writeOpts) 2207 if err != nil { 2208 t.Fatalf("Write failed: %s", err) 2209 } 2210 if want := int64(len(payload)); n != want { 2211 t.Fatalf("got n = %d, want = %d", n, want) 2212 } 2213 2214 pkt := c.LinkEP.Read() 2215 if pkt == nil { 2216 t.Fatal("Packet wasn't written out") 2217 } 2218 2219 v := stack.PayloadSince(pkt.NetworkHeader()) 2220 defer v.Release() 2221 pkt.DecRef() 2222 checker.IPv6(t, v, checker.UDP()) 2223 2224 // Simply replacing the payload with the checksum value is enough to make 2225 // sure that we end up with an all ones value for the ones complement sum 2226 // because the checksum value is held the ones complement of the ones 2227 // complement sum. 2228 // 2229 // In ones complement arithmetic, adding a value A with a ones complement of 2230 // another value B is the same as subtracting B from A. 2231 // 2232 // The resulting ones complement will be C' = C - C so we know C' will be 2233 // zero. The stack should never send a zero value though so we expect all 2234 // ones below. 2235 binary.BigEndian.PutUint16(payload[:], header.UDP(header.IPv6(v.AsSlice()).Payload()).Checksum()) 2236 } 2237 2238 { 2239 var r bytes.Reader 2240 r.Reset(payload[:]) 2241 n, err := c.EP.Write(&r, writeOpts) 2242 if err != nil { 2243 t.Fatalf("Write failed: %s", err) 2244 } 2245 if want := int64(len(payload)); n != want { 2246 t.Fatalf("got n = %d, want = %d", n, want) 2247 } 2248 } 2249 2250 { 2251 pkt := c.LinkEP.Read() 2252 if pkt == nil { 2253 t.Fatal("Packet wasn't written out") 2254 } 2255 defer pkt.DecRef() 2256 2257 v := stack.PayloadSince(pkt.NetworkHeader()) 2258 defer v.Release() 2259 checker.IPv6(t, v, checker.UDP(checker.TransportChecksum(math.MaxUint16))) 2260 2261 // Make sure the all ones checksum is valid. 2262 hdr := header.IPv6(v.AsSlice()) 2263 udp := header.UDP(hdr.Payload()) 2264 if src, dst, payloadXsum := hdr.SourceAddress(), hdr.DestinationAddress(), checksum.Checksum(udp.Payload(), 0); !udp.IsChecksumValid(src, dst, payloadXsum) { 2265 t.Errorf("got udp.IsChecksumValid(%s, %s, %d) = false, want = true", src, dst, payloadXsum) 2266 } 2267 } 2268 } 2269 2270 // TestWritePayloadSizeTooBig verifies that writing anything bigger than 2271 // header.UDPMaximumPacketSize fails. 2272 func TestWritePayloadSizeTooBig(t *testing.T) { 2273 for _, writeOpSequence := range writeOpSequences { 2274 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}) 2275 2276 c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber) 2277 2278 if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil { 2279 c.T.Fatalf("Connect failed: %s", err) 2280 } 2281 2282 testWriteOpSequenceSucceeds(c, context.UnicastV6, writeOpSequence) 2283 2284 for _, writeOp := range writeOpSequence { 2285 switch writeOp { 2286 case write: 2287 testWriteFails(c, context.UnicastV6, header.UDPMaximumPacketSize+1, &tcpip.ErrMessageTooLong{}) 2288 case preflight: 2289 testPreflightSucceeds(c, context.UnicastV6) 2290 } 2291 } 2292 c.Cleanup() 2293 } 2294 } 2295 2296 func TestMain(m *testing.M) { 2297 refs.SetLeakMode(refs.LeaksPanic) 2298 code := m.Run() 2299 refs.DoLeakCheck() 2300 os.Exit(code) 2301 }