github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/link/fdbased/endpoint_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // +build linux 16 17 package fdbased 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "reflect" 24 "testing" 25 "time" 26 "unsafe" 27 28 "github.com/google/go-cmp/cmp" 29 "golang.org/x/sys/unix" 30 "github.com/SagerNet/gvisor/pkg/tcpip" 31 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 32 "github.com/SagerNet/gvisor/pkg/tcpip/header" 33 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 34 ) 35 36 const ( 37 mtu = 1500 38 laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") 39 raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") 40 proto = 10 41 csumOffset = 48 42 gsoMSS = 500 43 ) 44 45 type packetInfo struct { 46 Raddr tcpip.LinkAddress 47 Proto tcpip.NetworkProtocolNumber 48 Contents *stack.PacketBuffer 49 } 50 51 type packetContents struct { 52 LinkHeader buffer.View 53 NetworkHeader buffer.View 54 TransportHeader buffer.View 55 Data buffer.View 56 } 57 58 func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { 59 t.Helper() 60 if diff := cmp.Diff( 61 want, got, 62 cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { 63 if pk == nil { 64 return nil 65 } 66 return &packetContents{ 67 LinkHeader: pk.LinkHeader().View(), 68 NetworkHeader: pk.NetworkHeader().View(), 69 TransportHeader: pk.TransportHeader().View(), 70 Data: pk.Data().AsRange().ToOwnedView(), 71 } 72 }), 73 ); diff != "" { 74 t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) 75 } 76 } 77 78 type context struct { 79 t *testing.T 80 readFDs []int 81 writeFDs []int 82 ep stack.LinkEndpoint 83 ch chan packetInfo 84 done chan struct{} 85 } 86 87 func newContext(t *testing.T, opt *Options) *context { 88 firstFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0) 89 if err != nil { 90 t.Fatalf("Socketpair failed: %v", err) 91 } 92 secondFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0) 93 if err != nil { 94 t.Fatalf("Socketpair failed: %v", err) 95 } 96 97 done := make(chan struct{}, 2) 98 opt.ClosedFunc = func(tcpip.Error) { 99 done <- struct{}{} 100 } 101 102 opt.FDs = []int{firstFDPair[1], secondFDPair[1]} 103 ep, err := New(opt) 104 if err != nil { 105 t.Fatalf("Failed to create FD endpoint: %v", err) 106 } 107 108 c := &context{ 109 t: t, 110 readFDs: []int{firstFDPair[0], secondFDPair[0]}, 111 writeFDs: opt.FDs, 112 ep: ep, 113 ch: make(chan packetInfo, 100), 114 done: done, 115 } 116 117 ep.Attach(c) 118 119 return c 120 } 121 122 func (c *context) cleanup() { 123 for _, fd := range c.readFDs { 124 unix.Close(fd) 125 } 126 <-c.done 127 <-c.done 128 for _, fd := range c.writeFDs { 129 unix.Close(fd) 130 } 131 } 132 133 func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { 134 c.ch <- packetInfo{remote, protocol, pkt} 135 } 136 137 func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { 138 panic("unimplemented") 139 } 140 141 func TestNoEthernetProperties(t *testing.T) { 142 c := newContext(t, &Options{MTU: mtu}) 143 defer c.cleanup() 144 145 if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { 146 t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) 147 } 148 149 if want, v := uint32(mtu), c.ep.MTU(); want != v { 150 t.Fatalf("MTU() = %v, want %v", v, want) 151 } 152 } 153 154 func TestEthernetProperties(t *testing.T) { 155 c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) 156 defer c.cleanup() 157 158 if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { 159 t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) 160 } 161 162 if want, v := uint32(mtu), c.ep.MTU(); want != v { 163 t.Fatalf("MTU() = %v, want %v", v, want) 164 } 165 } 166 167 func TestAddress(t *testing.T) { 168 addrs := []tcpip.LinkAddress{"", "abc", "def"} 169 for _, a := range addrs { 170 t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { 171 c := newContext(t, &Options{Address: a, MTU: mtu}) 172 defer c.cleanup() 173 174 if want, v := a, c.ep.LinkAddress(); want != v { 175 t.Fatalf("LinkAddress() = %v, want %v", v, want) 176 } 177 }) 178 } 179 } 180 181 func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { 182 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) 183 defer c.cleanup() 184 185 var r stack.RouteInfo 186 r.RemoteLinkAddress = raddr 187 188 // Build payload. 189 payload := buffer.NewView(plen) 190 if _, err := rand.Read(payload); err != nil { 191 t.Fatalf("rand.Read(payload): %s", err) 192 } 193 194 // Build packet buffer. 195 const netHdrLen = 100 196 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 197 ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, 198 Data: payload.ToVectorisedView(), 199 }) 200 pkt.Hash = hash 201 202 // Build header. 203 b := pkt.NetworkHeader().Push(netHdrLen) 204 if _, err := rand.Read(b); err != nil { 205 t.Fatalf("rand.Read(b): %s", err) 206 } 207 208 // Write. 209 want := append(append(buffer.View(nil), b...), payload...) 210 const l3HdrLen = header.IPv6MinimumSize 211 if gsoMaxSize != 0 { 212 pkt.GSOOptions = stack.GSO{ 213 Type: stack.GSOTCPv6, 214 NeedsCsum: true, 215 CsumOffset: csumOffset, 216 MSS: gsoMSS, 217 L3HdrLen: l3HdrLen, 218 } 219 } 220 if err := c.ep.WritePacket(r, proto, pkt); err != nil { 221 t.Fatalf("WritePacket failed: %v", err) 222 } 223 224 // Read from the corresponding FD, then compare with what we wrote. 225 b = make([]byte, mtu) 226 fd := c.readFDs[hash%uint32(len(c.readFDs))] 227 n, err := unix.Read(fd, b) 228 if err != nil { 229 t.Fatalf("Read failed: %v", err) 230 } 231 b = b[:n] 232 if gsoMaxSize != 0 { 233 vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) 234 if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { 235 t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) 236 } 237 const csumStart = header.EthernetMinimumSize + l3HdrLen 238 if vnetHdr.csumStart != csumStart { 239 t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) 240 } 241 if vnetHdr.csumOffset != csumOffset { 242 t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) 243 } 244 gsoType := uint8(0) 245 if plen > gsoMSS { 246 gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 247 } 248 if vnetHdr.gsoType != gsoType { 249 t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) 250 } 251 b = b[virtioNetHdrSize:] 252 } 253 if eth { 254 h := header.Ethernet(b) 255 b = b[header.EthernetMinimumSize:] 256 257 if a := h.SourceAddress(); a != laddr { 258 t.Fatalf("SourceAddress() = %v, want %v", a, laddr) 259 } 260 261 if a := h.DestinationAddress(); a != raddr { 262 t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) 263 } 264 265 if et := h.Type(); et != proto { 266 t.Fatalf("Type() = %v, want %v", et, proto) 267 } 268 } 269 if len(b) != len(want) { 270 t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) 271 } 272 if !bytes.Equal(b, want) { 273 t.Fatalf("Read returned %x, want %x", b, want) 274 } 275 } 276 277 func TestWritePacket(t *testing.T) { 278 lengths := []int{0, 100, 1000} 279 eths := []bool{true, false} 280 gsos := []uint32{0, 32768} 281 282 for _, eth := range eths { 283 for _, plen := range lengths { 284 for _, gso := range gsos { 285 t.Run( 286 fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), 287 func(t *testing.T) { 288 testWritePacket(t, plen, eth, gso, 0) 289 }, 290 ) 291 } 292 } 293 } 294 } 295 296 func TestHashedWritePacket(t *testing.T) { 297 lengths := []int{0, 100, 1000} 298 eths := []bool{true, false} 299 gsos := []uint32{0, 32768} 300 hashes := []uint32{0, 1} 301 for _, eth := range eths { 302 for _, plen := range lengths { 303 for _, gso := range gsos { 304 for _, hash := range hashes { 305 t.Run( 306 fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), 307 func(t *testing.T) { 308 testWritePacket(t, plen, eth, gso, hash) 309 }, 310 ) 311 } 312 } 313 } 314 } 315 } 316 317 func TestPreserveSrcAddress(t *testing.T) { 318 baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") 319 320 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) 321 defer c.cleanup() 322 323 // Set LocalLinkAddress in route to the value of the bridged address. 324 var r stack.RouteInfo 325 r.LocalLinkAddress = baddr 326 r.RemoteLinkAddress = raddr 327 328 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 329 // WritePacket panics given a prependable with anything less than 330 // the minimum size of the ethernet header. 331 // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). 332 ReserveHeaderBytes: header.EthernetMinimumSize, 333 Data: buffer.VectorisedView{}, 334 }) 335 if err := c.ep.WritePacket(r, proto, pkt); err != nil { 336 t.Fatalf("WritePacket failed: %v", err) 337 } 338 339 // Read from the FD, then compare with what we wrote. 340 b := make([]byte, mtu) 341 n, err := unix.Read(c.readFDs[0], b) 342 if err != nil { 343 t.Fatalf("Read failed: %v", err) 344 } 345 b = b[:n] 346 h := header.Ethernet(b) 347 348 if a := h.SourceAddress(); a != baddr { 349 t.Fatalf("SourceAddress() = %v, want %v", a, baddr) 350 } 351 } 352 353 func TestDeliverPacket(t *testing.T) { 354 lengths := []int{100, 1000} 355 eths := []bool{true, false} 356 357 for _, eth := range eths { 358 for _, plen := range lengths { 359 t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { 360 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) 361 defer c.cleanup() 362 363 // Build packet. 364 all := make([]byte, plen) 365 if _, err := rand.Read(all); err != nil { 366 t.Fatalf("rand.Read(all): %s", err) 367 } 368 // Make it look like an IPv4 packet. 369 all[0] = 0x40 370 371 wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 372 ReserveHeaderBytes: header.EthernetMinimumSize, 373 Data: buffer.NewViewFromBytes(all).ToVectorisedView(), 374 }) 375 if eth { 376 hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) 377 hdr.Encode(&header.EthernetFields{ 378 SrcAddr: raddr, 379 DstAddr: laddr, 380 Type: proto, 381 }) 382 all = append(hdr, all...) 383 } 384 385 // Write packet via the file descriptor. 386 if _, err := unix.Write(c.readFDs[0], all); err != nil { 387 t.Fatalf("Write failed: %v", err) 388 } 389 390 // Receive packet through the endpoint. 391 select { 392 case pi := <-c.ch: 393 want := packetInfo{ 394 Raddr: raddr, 395 Proto: proto, 396 Contents: wantPkt, 397 } 398 if !eth { 399 want.Proto = header.IPv4ProtocolNumber 400 want.Raddr = "" 401 } 402 checkPacketInfoEqual(t, pi, want) 403 case <-time.After(10 * time.Second): 404 t.Fatalf("Timed out waiting for packet") 405 } 406 }) 407 } 408 } 409 } 410 411 func TestBufConfigMaxLength(t *testing.T) { 412 got := 0 413 for _, i := range BufConfig { 414 got += i 415 } 416 want := header.MaxIPPacketSize // maximum TCP packet size 417 if got < want { 418 t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) 419 } 420 } 421 422 func TestBufConfigFirst(t *testing.T) { 423 // The stack assumes that the TCP/IP header is enterily contained in the first view. 424 // Therefore, the first view needs to be large enough to contain the maximum TCP/IP 425 // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). 426 want := 120 427 got := BufConfig[0] 428 if got < want { 429 t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) 430 } 431 } 432 433 var capLengthTestCases = []struct { 434 comment string 435 config []int 436 n int 437 wantUsed int 438 wantLengths []int 439 }{ 440 { 441 comment: "Single slice", 442 config: []int{2}, 443 n: 1, 444 wantUsed: 1, 445 wantLengths: []int{1}, 446 }, 447 { 448 comment: "Multiple slices", 449 config: []int{1, 2}, 450 n: 2, 451 wantUsed: 2, 452 wantLengths: []int{1, 1}, 453 }, 454 { 455 comment: "Entire buffer", 456 config: []int{1, 2}, 457 n: 3, 458 wantUsed: 2, 459 wantLengths: []int{1, 2}, 460 }, 461 { 462 comment: "Entire buffer but not on the last slice", 463 config: []int{1, 2, 3}, 464 n: 3, 465 wantUsed: 2, 466 wantLengths: []int{1, 2}, 467 }, 468 } 469 470 func TestIovecBuffer(t *testing.T) { 471 for _, c := range capLengthTestCases { 472 t.Run(c.comment, func(t *testing.T) { 473 b := newIovecBuffer(c.config, false /* skipsVnetHdr */) 474 475 // Test initial allocation. 476 iovecs := b.nextIovecs() 477 if got, want := len(iovecs), len(c.config); got != want { 478 t.Fatalf("len(iovecs) = %d, want %d", got, want) 479 } 480 481 // Make a copy as iovecs points to internal slice. We will need this state 482 // later. 483 oldIovecs := append([]unix.Iovec(nil), iovecs...) 484 485 // Test the views that get pulled. 486 vv := b.pullViews(c.n) 487 var lengths []int 488 for _, v := range vv.Views() { 489 lengths = append(lengths, len(v)) 490 } 491 if !reflect.DeepEqual(lengths, c.wantLengths) { 492 t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) 493 } 494 495 // Test that new views get reallocated. 496 for i, newIov := range b.nextIovecs() { 497 if i < c.wantUsed { 498 if newIov.Base == oldIovecs[i].Base { 499 t.Errorf("b.views[%d] should have been reallocated", i) 500 } 501 } else { 502 if newIov.Base != oldIovecs[i].Base { 503 t.Errorf("b.views[%d] should not have been reallocated", i) 504 } 505 } 506 } 507 }) 508 } 509 } 510 511 func TestIovecBufferSkipVnetHdr(t *testing.T) { 512 for _, test := range []struct { 513 desc string 514 readN int 515 wantLen int 516 }{ 517 { 518 desc: "nothing read", 519 readN: 0, 520 wantLen: 0, 521 }, 522 { 523 desc: "smaller than vnet header", 524 readN: virtioNetHdrSize - 1, 525 wantLen: 0, 526 }, 527 { 528 desc: "header skipped", 529 readN: virtioNetHdrSize + 100, 530 wantLen: 100, 531 }, 532 } { 533 t.Run(test.desc, func(t *testing.T) { 534 b := newIovecBuffer([]int{10, 20, 50, 50}, true) 535 // Pretend a read happend. 536 b.nextIovecs() 537 vv := b.pullViews(test.readN) 538 if got, want := vv.Size(), test.wantLen; got != want { 539 t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) 540 } 541 if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { 542 t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) 543 } 544 }) 545 } 546 } 547 548 // fakeNetworkDispatcher delivers packets to pkts. 549 type fakeNetworkDispatcher struct { 550 pkts []*stack.PacketBuffer 551 } 552 553 func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { 554 d.pkts = append(d.pkts, pkt) 555 } 556 557 func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { 558 panic("unimplemented") 559 } 560 561 func TestDispatchPacketFormat(t *testing.T) { 562 for _, test := range []struct { 563 name string 564 newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) 565 }{ 566 { 567 name: "readVDispatcher", 568 newDispatcher: newReadVDispatcher, 569 }, 570 { 571 name: "recvMMsgDispatcher", 572 newDispatcher: newRecvMMsgDispatcher, 573 }, 574 } { 575 t.Run(test.name, func(t *testing.T) { 576 // Create a socket pair to send/recv. 577 fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_DGRAM, 0) 578 if err != nil { 579 t.Fatal(err) 580 } 581 defer unix.Close(fds[0]) 582 defer unix.Close(fds[1]) 583 584 data := []byte{ 585 // Ethernet header. 586 1, 2, 3, 4, 5, 60, 587 1, 2, 3, 4, 5, 61, 588 8, 0, 589 // Mock network header. 590 40, 41, 42, 43, 591 } 592 err = unix.Sendmsg(fds[1], data, nil, nil, 0) 593 if err != nil { 594 t.Fatal(err) 595 } 596 597 // Create and run dispatcher once. 598 sink := &fakeNetworkDispatcher{} 599 d, err := test.newDispatcher(fds[0], &endpoint{ 600 hdrSize: header.EthernetMinimumSize, 601 dispatcher: sink, 602 }) 603 if err != nil { 604 t.Fatal(err) 605 } 606 if ok, err := d.dispatch(); !ok || err != nil { 607 t.Fatalf("d.dispatch() = %v, %v", ok, err) 608 } 609 610 // Verify packet. 611 if got, want := len(sink.pkts), 1; got != want { 612 t.Fatalf("len(sink.pkts) = %d, want %d", got, want) 613 } 614 pkt := sink.pkts[0] 615 if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { 616 t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) 617 } 618 if got, want := pkt.Data().Size(), 4; got != want { 619 t.Errorf("pkt.Data().Size() = %d, want %d", got, want) 620 } 621 }) 622 } 623 }