github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/link/fdbased/endpoint_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // +build linux
    16  
    17  package fdbased
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"reflect"
    24  	"syscall"
    25  	"testing"
    26  	"time"
    27  	"unsafe"
    28  
    29  	"github.com/FlowerWrong/netstack/tcpip"
    30  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    31  	"github.com/FlowerWrong/netstack/tcpip/header"
    32  	"github.com/FlowerWrong/netstack/tcpip/link/rawfile"
    33  	"github.com/FlowerWrong/netstack/tcpip/stack"
    34  )
    35  
    36  const (
    37  	mtu        = 1500
    38  	laddr      = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
    39  	raddr      = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
    40  	proto      = 10
    41  	csumOffset = 48
    42  	gsoMSS     = 500
    43  )
    44  
// packetInfo records one packet handed to the test context through
// DeliverNetworkPacket: the remote link address and network protocol the
// endpoint reported, plus the packet contents flattened into a single view.
//
// Field order is load-bearing: DeliverNetworkPacket builds values with a
// positional composite literal.
type packetInfo struct {
	raddr    tcpip.LinkAddress
	proto    tcpip.NetworkProtocolNumber
	contents buffer.View
}
    50  
    51  type context struct {
    52  	t    *testing.T
    53  	fds  [2]int
    54  	ep   stack.LinkEndpoint
    55  	ch   chan packetInfo
    56  	done chan struct{}
    57  }
    58  
    59  func newContext(t *testing.T, opt *Options) *context {
    60  	fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
    61  	if err != nil {
    62  		t.Fatalf("Socketpair failed: %v", err)
    63  	}
    64  
    65  	done := make(chan struct{}, 1)
    66  	opt.ClosedFunc = func(*tcpip.Error) {
    67  		done <- struct{}{}
    68  	}
    69  
    70  	opt.FDs = []int{fds[1]}
    71  	ep, err := New(opt)
    72  	if err != nil {
    73  		t.Fatalf("Failed to create FD endpoint: %v", err)
    74  	}
    75  
    76  	c := &context{
    77  		t:    t,
    78  		fds:  fds,
    79  		ep:   ep,
    80  		ch:   make(chan packetInfo, 100),
    81  		done: done,
    82  	}
    83  
    84  	ep.Attach(c)
    85  
    86  	return c
    87  }
    88  
// cleanup tears down a context created by newContext.
//
// It closes the test's end of the socketpair (fds[0]) first, then waits on
// c.done, which newContext wires to the endpoint's ClosedFunc. Only after that
// signal does it close the end that was handed to the endpoint (fds[1]).
// Presumably closing the peer is what makes the endpoint notice the socket is
// gone and invoke ClosedFunc — TODO confirm against the dispatcher loop.
// NOTE(review): this blocks forever if ClosedFunc is never invoked.
func (c *context) cleanup() {
	syscall.Close(c.fds[0])
	<-c.done
	syscall.Close(c.fds[1])
}
    94  
    95  func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
    96  	c.ch <- packetInfo{remote, protocol, vv.ToView()}
    97  }
    98  
    99  func TestNoEthernetProperties(t *testing.T) {
   100  	c := newContext(t, &Options{MTU: mtu})
   101  	defer c.cleanup()
   102  
   103  	if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
   104  		t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
   105  	}
   106  
   107  	if want, v := uint32(mtu), c.ep.MTU(); want != v {
   108  		t.Fatalf("MTU() = %v, want %v", v, want)
   109  	}
   110  }
   111  
   112  func TestEthernetProperties(t *testing.T) {
   113  	c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
   114  	defer c.cleanup()
   115  
   116  	if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
   117  		t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
   118  	}
   119  
   120  	if want, v := uint32(mtu), c.ep.MTU(); want != v {
   121  		t.Fatalf("MTU() = %v, want %v", v, want)
   122  	}
   123  }
   124  
   125  func TestAddress(t *testing.T) {
   126  	addrs := []tcpip.LinkAddress{"", "abc", "def"}
   127  	for _, a := range addrs {
   128  		t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
   129  			c := newContext(t, &Options{Address: a, MTU: mtu})
   130  			defer c.cleanup()
   131  
   132  			if want, v := a, c.ep.LinkAddress(); want != v {
   133  				t.Fatalf("LinkAddress() = %v, want %v", v, want)
   134  			}
   135  		})
   136  	}
   137  }
   138  
   139  func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
   140  	c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
   141  	defer c.cleanup()
   142  
   143  	r := &stack.Route{
   144  		RemoteLinkAddress: raddr,
   145  	}
   146  
   147  	// Build header.
   148  	hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
   149  	b := hdr.Prepend(100)
   150  	for i := range b {
   151  		b[i] = uint8(rand.Intn(256))
   152  	}
   153  
   154  	// Build payload and write.
   155  	payload := make(buffer.View, plen)
   156  	for i := range payload {
   157  		payload[i] = uint8(rand.Intn(256))
   158  	}
   159  	want := append(hdr.View(), payload...)
   160  	var gso *stack.GSO
   161  	if gsoMaxSize != 0 {
   162  		gso = &stack.GSO{
   163  			Type:       stack.GSOTCPv6,
   164  			NeedsCsum:  true,
   165  			CsumOffset: csumOffset,
   166  			MSS:        gsoMSS,
   167  			MaxSize:    gsoMaxSize,
   168  			L3HdrLen:   header.IPv4MaximumHeaderSize,
   169  		}
   170  	}
   171  	if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil {
   172  		t.Fatalf("WritePacket failed: %v", err)
   173  	}
   174  
   175  	// Read from fd, then compare with what we wrote.
   176  	b = make([]byte, mtu)
   177  	n, err := syscall.Read(c.fds[0], b)
   178  	if err != nil {
   179  		t.Fatalf("Read failed: %v", err)
   180  	}
   181  	b = b[:n]
   182  	if gsoMaxSize != 0 {
   183  		vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0]))
   184  		if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 {
   185  			t.Fatalf("virtioNetHdr.flags %v  doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM)
   186  		}
   187  		csumStart := header.EthernetMinimumSize + gso.L3HdrLen
   188  		if vnetHdr.csumStart != csumStart {
   189  			t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart)
   190  		}
   191  		if vnetHdr.csumOffset != csumOffset {
   192  			t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset)
   193  		}
   194  		gsoType := uint8(0)
   195  		if int(gso.MSS) < plen {
   196  			gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
   197  		}
   198  		if vnetHdr.gsoType != gsoType {
   199  			t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType)
   200  		}
   201  		b = b[virtioNetHdrSize:]
   202  	}
   203  	if eth {
   204  		h := header.Ethernet(b)
   205  		b = b[header.EthernetMinimumSize:]
   206  
   207  		if a := h.SourceAddress(); a != laddr {
   208  			t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
   209  		}
   210  
   211  		if a := h.DestinationAddress(); a != raddr {
   212  			t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
   213  		}
   214  
   215  		if et := h.Type(); et != proto {
   216  			t.Fatalf("Type() = %v, want %v", et, proto)
   217  		}
   218  	}
   219  	if len(b) != len(want) {
   220  		t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
   221  	}
   222  	if !bytes.Equal(b, want) {
   223  		t.Fatalf("Read returned %x, want %x", b, want)
   224  	}
   225  }
   226  
   227  func TestWritePacket(t *testing.T) {
   228  	lengths := []int{0, 100, 1000}
   229  	eths := []bool{true, false}
   230  	gsos := []uint32{0, 32768}
   231  
   232  	for _, eth := range eths {
   233  		for _, plen := range lengths {
   234  			for _, gso := range gsos {
   235  				t.Run(
   236  					fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
   237  					func(t *testing.T) {
   238  						testWritePacket(t, plen, eth, gso)
   239  					},
   240  				)
   241  			}
   242  		}
   243  	}
   244  }
   245  
   246  func TestPreserveSrcAddress(t *testing.T) {
   247  	baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
   248  
   249  	c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
   250  	defer c.cleanup()
   251  
   252  	// Set LocalLinkAddress in route to the value of the bridged address.
   253  	r := &stack.Route{
   254  		RemoteLinkAddress: raddr,
   255  		LocalLinkAddress:  baddr,
   256  	}
   257  
   258  	// WritePacket panics given a prependable with anything less than
   259  	// the minimum size of the ethernet header.
   260  	hdr := buffer.NewPrependable(header.EthernetMinimumSize)
   261  	if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil {
   262  		t.Fatalf("WritePacket failed: %v", err)
   263  	}
   264  
   265  	// Read from the FD, then compare with what we wrote.
   266  	b := make([]byte, mtu)
   267  	n, err := syscall.Read(c.fds[0], b)
   268  	if err != nil {
   269  		t.Fatalf("Read failed: %v", err)
   270  	}
   271  	b = b[:n]
   272  	h := header.Ethernet(b)
   273  
   274  	if a := h.SourceAddress(); a != baddr {
   275  		t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
   276  	}
   277  }
   278  
   279  func TestDeliverPacket(t *testing.T) {
   280  	lengths := []int{100, 1000}
   281  	eths := []bool{true, false}
   282  
   283  	for _, eth := range eths {
   284  		for _, plen := range lengths {
   285  			t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
   286  				c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
   287  				defer c.cleanup()
   288  
   289  				// Build packet.
   290  				b := make([]byte, plen)
   291  				all := b
   292  				for i := range b {
   293  					b[i] = uint8(rand.Intn(256))
   294  				}
   295  
   296  				if !eth {
   297  					// So that it looks like an IPv4 packet.
   298  					b[0] = 0x40
   299  				} else {
   300  					hdr := make(header.Ethernet, header.EthernetMinimumSize)
   301  					hdr.Encode(&header.EthernetFields{
   302  						SrcAddr: raddr,
   303  						DstAddr: laddr,
   304  						Type:    proto,
   305  					})
   306  					all = append(hdr, b...)
   307  				}
   308  
   309  				// Write packet via the file descriptor.
   310  				if _, err := syscall.Write(c.fds[0], all); err != nil {
   311  					t.Fatalf("Write failed: %v", err)
   312  				}
   313  
   314  				// Receive packet through the endpoint.
   315  				select {
   316  				case pi := <-c.ch:
   317  					want := packetInfo{
   318  						raddr:    raddr,
   319  						proto:    proto,
   320  						contents: b,
   321  					}
   322  					if !eth {
   323  						want.proto = header.IPv4ProtocolNumber
   324  						want.raddr = ""
   325  					}
   326  					if !reflect.DeepEqual(want, pi) {
   327  						t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
   328  					}
   329  				case <-time.After(10 * time.Second):
   330  					t.Fatalf("Timed out waiting for packet")
   331  				}
   332  			})
   333  		}
   334  	}
   335  }
   336  
   337  func TestBufConfigMaxLength(t *testing.T) {
   338  	got := 0
   339  	for _, i := range BufConfig {
   340  		got += i
   341  	}
   342  	want := header.MaxIPPacketSize // maximum TCP packet size
   343  	if got < want {
   344  		t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
   345  	}
   346  }
   347  
   348  func TestBufConfigFirst(t *testing.T) {
   349  	// The stack assumes that the TCP/IP header is enterily contained in the first view.
   350  	// Therefore, the first view needs to be large enough to contain the maximum TCP/IP
   351  	// header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
   352  	want := 120
   353  	got := BufConfig[0]
   354  	if got < want {
   355  		t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
   356  	}
   357  }
   358  
   359  var capLengthTestCases = []struct {
   360  	comment     string
   361  	config      []int
   362  	n           int
   363  	wantUsed    int
   364  	wantLengths []int
   365  }{
   366  	{
   367  		comment:     "Single slice",
   368  		config:      []int{2},
   369  		n:           1,
   370  		wantUsed:    1,
   371  		wantLengths: []int{1},
   372  	},
   373  	{
   374  		comment:     "Multiple slices",
   375  		config:      []int{1, 2},
   376  		n:           2,
   377  		wantUsed:    2,
   378  		wantLengths: []int{1, 1},
   379  	},
   380  	{
   381  		comment:     "Entire buffer",
   382  		config:      []int{1, 2},
   383  		n:           3,
   384  		wantUsed:    2,
   385  		wantLengths: []int{1, 2},
   386  	},
   387  	{
   388  		comment:     "Entire buffer but not on the last slice",
   389  		config:      []int{1, 2, 3},
   390  		n:           3,
   391  		wantUsed:    2,
   392  		wantLengths: []int{1, 2, 3},
   393  	},
   394  }
   395  
   396  func TestReadVDispatcherCapLength(t *testing.T) {
   397  	for _, c := range capLengthTestCases {
   398  		// fd does not matter for this test.
   399  		d := readVDispatcher{fd: -1, e: &endpoint{}}
   400  		d.views = make([]buffer.View, len(c.config))
   401  		d.iovecs = make([]syscall.Iovec, len(c.config))
   402  		d.allocateViews(c.config)
   403  
   404  		used := d.capViews(c.n, c.config)
   405  		if used != c.wantUsed {
   406  			t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
   407  		}
   408  		lengths := make([]int, len(d.views))
   409  		for i, v := range d.views {
   410  			lengths[i] = len(v)
   411  		}
   412  		if !reflect.DeepEqual(lengths, c.wantLengths) {
   413  			t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
   414  		}
   415  	}
   416  }
   417  
   418  func TestRecvMMsgDispatcherCapLength(t *testing.T) {
   419  	for _, c := range capLengthTestCases {
   420  		d := recvMMsgDispatcher{
   421  			fd:      -1, // fd does not matter for this test.
   422  			e:       &endpoint{},
   423  			views:   make([][]buffer.View, 1),
   424  			iovecs:  make([][]syscall.Iovec, 1),
   425  			msgHdrs: make([]rawfile.MMsgHdr, 1),
   426  		}
   427  
   428  		for i, _ := range d.views {
   429  			d.views[i] = make([]buffer.View, len(c.config))
   430  		}
   431  		for i := range d.iovecs {
   432  			d.iovecs[i] = make([]syscall.Iovec, len(c.config))
   433  		}
   434  		for k, msgHdr := range d.msgHdrs {
   435  			msgHdr.Msg.Iov = &d.iovecs[k][0]
   436  			msgHdr.Msg.Iovlen = uint64(len(c.config))
   437  		}
   438  
   439  		d.allocateViews(c.config)
   440  
   441  		used := d.capViews(0, c.n, c.config)
   442  		if used != c.wantUsed {
   443  			t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
   444  		}
   445  		lengths := make([]int, len(d.views[0]))
   446  		for i, v := range d.views[0] {
   447  			lengths[i] = len(v)
   448  		}
   449  		if !reflect.DeepEqual(lengths, c.wantLengths) {
   450  			t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
   451  		}
   452  
   453  	}
   454  }