github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/link/fdbased/endpoint_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // +build linux 16 17 package fdbased 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "reflect" 24 "syscall" 25 "testing" 26 "time" 27 "unsafe" 28 29 "github.com/FlowerWrong/netstack/tcpip" 30 "github.com/FlowerWrong/netstack/tcpip/buffer" 31 "github.com/FlowerWrong/netstack/tcpip/header" 32 "github.com/FlowerWrong/netstack/tcpip/link/rawfile" 33 "github.com/FlowerWrong/netstack/tcpip/stack" 34 ) 35 36 const ( 37 mtu = 1500 38 laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") 39 raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") 40 proto = 10 41 csumOffset = 48 42 gsoMSS = 500 43 ) 44 45 type packetInfo struct { 46 raddr tcpip.LinkAddress 47 proto tcpip.NetworkProtocolNumber 48 contents buffer.View 49 } 50 51 type context struct { 52 t *testing.T 53 fds [2]int 54 ep stack.LinkEndpoint 55 ch chan packetInfo 56 done chan struct{} 57 } 58 59 func newContext(t *testing.T, opt *Options) *context { 60 fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) 61 if err != nil { 62 t.Fatalf("Socketpair failed: %v", err) 63 } 64 65 done := make(chan struct{}, 1) 66 opt.ClosedFunc = func(*tcpip.Error) { 67 done <- struct{}{} 68 } 69 70 opt.FDs = []int{fds[1]} 71 ep, err := New(opt) 72 if err != nil { 73 t.Fatalf("Failed to create FD endpoint: %v", err) 74 } 75 76 c := &context{ 77 t: t, 78 fds: fds, 79 ep: ep, 80 ch: make(chan packetInfo, 100), 81 done: done, 82 } 83 84 ep.Attach(c) 85 86 return c 87 } 88 89 func (c *context) cleanup() { 90 syscall.Close(c.fds[0]) 91 <-c.done 92 syscall.Close(c.fds[1]) 93 } 94 95 func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { 96 c.ch <- packetInfo{remote, protocol, vv.ToView()} 97 } 98 99 func TestNoEthernetProperties(t *testing.T) { 100 c := newContext(t, &Options{MTU: mtu}) 101 defer c.cleanup() 102 103 if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { 104 t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) 105 } 106 107 if want, v := uint32(mtu), c.ep.MTU(); want != v { 108 t.Fatalf("MTU() = %v, want %v", v, want) 109 } 110 } 111 112 func TestEthernetProperties(t *testing.T) { 113 c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) 114 defer c.cleanup() 115 116 if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { 117 t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) 118 } 119 120 if want, v := uint32(mtu), c.ep.MTU(); want != v { 121 t.Fatalf("MTU() = %v, want %v", v, want) 122 } 123 } 124 125 func TestAddress(t *testing.T) { 126 addrs := []tcpip.LinkAddress{"", "abc", "def"} 127 for _, a := range addrs { 128 t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { 129 c := newContext(t, &Options{Address: a, MTU: mtu}) 130 defer c.cleanup() 131 132 if want, v := a, c.ep.LinkAddress(); want != v { 133 t.Fatalf("LinkAddress() = %v, want %v", v, want) 134 } 135 }) 136 } 137 } 138 139 func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { 140 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) 141 defer c.cleanup() 142 143 r := &stack.Route{ 144 RemoteLinkAddress: raddr, 145 } 146 147 // Build header. 148 hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) 149 b := hdr.Prepend(100) 150 for i := range b { 151 b[i] = uint8(rand.Intn(256)) 152 } 153 154 // Build payload and write. 155 payload := make(buffer.View, plen) 156 for i := range payload { 157 payload[i] = uint8(rand.Intn(256)) 158 } 159 want := append(hdr.View(), payload...) 160 var gso *stack.GSO 161 if gsoMaxSize != 0 { 162 gso = &stack.GSO{ 163 Type: stack.GSOTCPv6, 164 NeedsCsum: true, 165 CsumOffset: csumOffset, 166 MSS: gsoMSS, 167 MaxSize: gsoMaxSize, 168 L3HdrLen: header.IPv4MaximumHeaderSize, 169 } 170 } 171 if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { 172 t.Fatalf("WritePacket failed: %v", err) 173 } 174 175 // Read from fd, then compare with what we wrote. 176 b = make([]byte, mtu) 177 n, err := syscall.Read(c.fds[0], b) 178 if err != nil { 179 t.Fatalf("Read failed: %v", err) 180 } 181 b = b[:n] 182 if gsoMaxSize != 0 { 183 vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) 184 if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { 185 t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) 186 } 187 csumStart := header.EthernetMinimumSize + gso.L3HdrLen 188 if vnetHdr.csumStart != csumStart { 189 t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) 190 } 191 if vnetHdr.csumOffset != csumOffset { 192 t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) 193 } 194 gsoType := uint8(0) 195 if int(gso.MSS) < plen { 196 gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 197 } 198 if vnetHdr.gsoType != gsoType { 199 t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) 200 } 201 b = b[virtioNetHdrSize:] 202 } 203 if eth { 204 h := header.Ethernet(b) 205 b = b[header.EthernetMinimumSize:] 206 207 if a := h.SourceAddress(); a != laddr { 208 t.Fatalf("SourceAddress() = %v, want %v", a, laddr) 209 } 210 211 if a := h.DestinationAddress(); a != raddr { 212 t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) 213 } 214 215 if et := h.Type(); et != proto { 216 t.Fatalf("Type() = %v, want %v", et, proto) 217 } 218 } 219 if len(b) != len(want) { 220 t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) 221 } 222 if !bytes.Equal(b, want) { 223 t.Fatalf("Read returned %x, want %x", b, want) 224 } 225 } 226 227 func TestWritePacket(t *testing.T) { 228 lengths := []int{0, 100, 1000} 229 eths := []bool{true, false} 230 gsos := []uint32{0, 32768} 231 232 for _, eth := range eths { 233 for _, plen := range lengths { 234 for _, gso := range gsos { 235 t.Run( 236 fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), 237 func(t *testing.T) { 238 testWritePacket(t, plen, eth, gso) 239 }, 240 ) 241 } 242 } 243 } 244 } 245 246 func TestPreserveSrcAddress(t *testing.T) { 247 baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") 248 249 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) 250 defer c.cleanup() 251 252 // Set LocalLinkAddress in route to the value of the bridged address. 253 r := &stack.Route{ 254 RemoteLinkAddress: raddr, 255 LocalLinkAddress: baddr, 256 } 257 258 // WritePacket panics given a prependable with anything less than 259 // the minimum size of the ethernet header. 260 hdr := buffer.NewPrependable(header.EthernetMinimumSize) 261 if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { 262 t.Fatalf("WritePacket failed: %v", err) 263 } 264 265 // Read from the FD, then compare with what we wrote. 266 b := make([]byte, mtu) 267 n, err := syscall.Read(c.fds[0], b) 268 if err != nil { 269 t.Fatalf("Read failed: %v", err) 270 } 271 b = b[:n] 272 h := header.Ethernet(b) 273 274 if a := h.SourceAddress(); a != baddr { 275 t.Fatalf("SourceAddress() = %v, want %v", a, baddr) 276 } 277 } 278 279 func TestDeliverPacket(t *testing.T) { 280 lengths := []int{100, 1000} 281 eths := []bool{true, false} 282 283 for _, eth := range eths { 284 for _, plen := range lengths { 285 t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { 286 c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) 287 defer c.cleanup() 288 289 // Build packet. 290 b := make([]byte, plen) 291 all := b 292 for i := range b { 293 b[i] = uint8(rand.Intn(256)) 294 } 295 296 if !eth { 297 // So that it looks like an IPv4 packet. 298 b[0] = 0x40 299 } else { 300 hdr := make(header.Ethernet, header.EthernetMinimumSize) 301 hdr.Encode(&header.EthernetFields{ 302 SrcAddr: raddr, 303 DstAddr: laddr, 304 Type: proto, 305 }) 306 all = append(hdr, b...) 307 } 308 309 // Write packet via the file descriptor. 310 if _, err := syscall.Write(c.fds[0], all); err != nil { 311 t.Fatalf("Write failed: %v", err) 312 } 313 314 // Receive packet through the endpoint. 315 select { 316 case pi := <-c.ch: 317 want := packetInfo{ 318 raddr: raddr, 319 proto: proto, 320 contents: b, 321 } 322 if !eth { 323 want.proto = header.IPv4ProtocolNumber 324 want.raddr = "" 325 } 326 if !reflect.DeepEqual(want, pi) { 327 t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) 328 } 329 case <-time.After(10 * time.Second): 330 t.Fatalf("Timed out waiting for packet") 331 } 332 }) 333 } 334 } 335 } 336 337 func TestBufConfigMaxLength(t *testing.T) { 338 got := 0 339 for _, i := range BufConfig { 340 got += i 341 } 342 want := header.MaxIPPacketSize // maximum TCP packet size 343 if got < want { 344 t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) 345 } 346 } 347 348 func TestBufConfigFirst(t *testing.T) { 349 // The stack assumes that the TCP/IP header is enterily contained in the first view. 350 // Therefore, the first view needs to be large enough to contain the maximum TCP/IP 351 // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). 352 want := 120 353 got := BufConfig[0] 354 if got < want { 355 t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) 356 } 357 } 358 359 var capLengthTestCases = []struct { 360 comment string 361 config []int 362 n int 363 wantUsed int 364 wantLengths []int 365 }{ 366 { 367 comment: "Single slice", 368 config: []int{2}, 369 n: 1, 370 wantUsed: 1, 371 wantLengths: []int{1}, 372 }, 373 { 374 comment: "Multiple slices", 375 config: []int{1, 2}, 376 n: 2, 377 wantUsed: 2, 378 wantLengths: []int{1, 1}, 379 }, 380 { 381 comment: "Entire buffer", 382 config: []int{1, 2}, 383 n: 3, 384 wantUsed: 2, 385 wantLengths: []int{1, 2}, 386 }, 387 { 388 comment: "Entire buffer but not on the last slice", 389 config: []int{1, 2, 3}, 390 n: 3, 391 wantUsed: 2, 392 wantLengths: []int{1, 2, 3}, 393 }, 394 } 395 396 func TestReadVDispatcherCapLength(t *testing.T) { 397 for _, c := range capLengthTestCases { 398 // fd does not matter for this test. 399 d := readVDispatcher{fd: -1, e: &endpoint{}} 400 d.views = make([]buffer.View, len(c.config)) 401 d.iovecs = make([]syscall.Iovec, len(c.config)) 402 d.allocateViews(c.config) 403 404 used := d.capViews(c.n, c.config) 405 if used != c.wantUsed { 406 t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) 407 } 408 lengths := make([]int, len(d.views)) 409 for i, v := range d.views { 410 lengths[i] = len(v) 411 } 412 if !reflect.DeepEqual(lengths, c.wantLengths) { 413 t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) 414 } 415 } 416 } 417 418 func TestRecvMMsgDispatcherCapLength(t *testing.T) { 419 for _, c := range capLengthTestCases { 420 d := recvMMsgDispatcher{ 421 fd: -1, // fd does not matter for this test. 422 e: &endpoint{}, 423 views: make([][]buffer.View, 1), 424 iovecs: make([][]syscall.Iovec, 1), 425 msgHdrs: make([]rawfile.MMsgHdr, 1), 426 } 427 428 for i, _ := range d.views { 429 d.views[i] = make([]buffer.View, len(c.config)) 430 } 431 for i := range d.iovecs { 432 d.iovecs[i] = make([]syscall.Iovec, len(c.config)) 433 } 434 for k, msgHdr := range d.msgHdrs { 435 msgHdr.Msg.Iov = &d.iovecs[k][0] 436 msgHdr.Msg.Iovlen = uint64(len(c.config)) 437 } 438 439 d.allocateViews(c.config) 440 441 used := d.capViews(0, c.n, c.config) 442 if used != c.wantUsed { 443 t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) 444 } 445 lengths := make([]int, len(d.views[0])) 446 for i, v := range d.views[0] { 447 lengths[i] = len(v) 448 } 449 if !reflect.DeepEqual(lengths, c.wantLengths) { 450 t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) 451 } 452 453 } 454 }