github.com/flowerwrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/network/ipv4/ipv4_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ipv4_test
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/hex"
    20  	"math/rand"
    21  	"testing"
    22  
    23  	"github.com/FlowerWrong/netstack/tcpip"
    24  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    25  	"github.com/FlowerWrong/netstack/tcpip/header"
    26  	"github.com/FlowerWrong/netstack/tcpip/link/channel"
    27  	"github.com/FlowerWrong/netstack/tcpip/link/sniffer"
    28  	"github.com/FlowerWrong/netstack/tcpip/network/ipv4"
    29  	"github.com/FlowerWrong/netstack/tcpip/stack"
    30  	"github.com/FlowerWrong/netstack/tcpip/transport/tcp"
    31  	"github.com/FlowerWrong/netstack/tcpip/transport/udp"
    32  	"github.com/FlowerWrong/netstack/waiter"
    33  )
    34  
    35  func TestExcludeBroadcast(t *testing.T) {
    36  	s := stack.New(stack.Options{
    37  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
    38  		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
    39  	})
    40  
    41  	const defaultMTU = 65536
    42  	ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
    43  	if testing.Verbose() {
    44  		ep = sniffer.New(ep)
    45  	}
    46  	if err := s.CreateNIC(1, ep); err != nil {
    47  		t.Fatalf("CreateNIC failed: %v", err)
    48  	}
    49  
    50  	if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
    51  		t.Fatalf("AddAddress failed: %v", err)
    52  	}
    53  
    54  	s.SetRouteTable([]tcpip.Route{{
    55  		Destination: header.IPv4EmptySubnet,
    56  		NIC:         1,
    57  	}})
    58  
    59  	randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
    60  
    61  	var wq waiter.Queue
    62  	t.Run("WithoutPrimaryAddress", func(t *testing.T) {
    63  		ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
    64  		if err != nil {
    65  			t.Fatal(err)
    66  		}
    67  		defer ep.Close()
    68  
    69  		// Cannot connect using a broadcast address as the source.
    70  		if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
    71  			t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
    72  		}
    73  
    74  		// However, we can bind to a broadcast address to listen.
    75  		if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
    76  			t.Errorf("Bind failed: %v", err)
    77  		}
    78  	})
    79  
    80  	t.Run("WithPrimaryAddress", func(t *testing.T) {
    81  		ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
    82  		if err != nil {
    83  			t.Fatal(err)
    84  		}
    85  		defer ep.Close()
    86  
    87  		// Add a valid primary endpoint address, now we can connect.
    88  		if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
    89  			t.Fatalf("AddAddress failed: %v", err)
    90  		}
    91  		if err := ep.Connect(randomAddr); err != nil {
    92  			t.Errorf("Connect failed: %v", err)
    93  		}
    94  	})
    95  }
    96  
    97  // makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
    98  // data should already be in the header before WritePacket. extraLength
    99  // indicates how much extra space should be in the header. The payload is made
   100  // from many Views of the sizes listed in viewSizes.
   101  func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
   102  	hdr := buffer.NewPrependable(hdrLength + extraLength)
   103  	hdr.Prepend(hdrLength)
   104  	rand.Read(hdr.View())
   105  
   106  	var views []buffer.View
   107  	totalLength := 0
   108  	for _, s := range viewSizes {
   109  		newView := buffer.NewView(s)
   110  		rand.Read(newView)
   111  		views = append(views, newView)
   112  		totalLength += s
   113  	}
   114  	payload := buffer.NewVectorisedView(totalLength, views)
   115  	return hdr, payload
   116  }
   117  
   118  // comparePayloads compared the contents of all the packets against the contents
   119  // of the source packet.
   120  func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
   121  	t.Helper()
   122  	// Make a complete array of the sourcePacketInfo packet.
   123  	source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
   124  	source = append(source, sourcePacketInfo.Header.View()...)
   125  	source = append(source, sourcePacketInfo.Payload.ToView()...)
   126  
   127  	// Make a copy of the IP header, which will be modified in some fields to make
   128  	// an expected header.
   129  	sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
   130  	sourceCopy.SetChecksum(0)
   131  	sourceCopy.SetFlagsFragmentOffset(0, 0)
   132  	sourceCopy.SetTotalLength(0)
   133  	var offset uint16
   134  	// Build up an array of the bytes sent.
   135  	var reassembledPayload []byte
   136  	for i, packet := range packets {
   137  		// Confirm that the packet is valid.
   138  		allBytes := packet.Header.View().ToVectorisedView()
   139  		allBytes.Append(packet.Payload)
   140  		ip := header.IPv4(allBytes.ToView())
   141  		if !ip.IsValid(len(ip)) {
   142  			t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
   143  		}
   144  		if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
   145  			t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
   146  		}
   147  		if got, want := len(ip), int(mtu); got > want {
   148  			t.Errorf("fragment is too large, got %d want %d", got, want)
   149  		}
   150  		if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
   151  			t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
   152  		}
   153  		if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
   154  			t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
   155  		}
   156  		if i < len(packets)-1 {
   157  			sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
   158  		} else {
   159  			sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
   160  		}
   161  		reassembledPayload = append(reassembledPayload, ip.Payload()...)
   162  		offset += ip.TotalLength() - uint16(ip.HeaderLength())
   163  		// Clear out the checksum and length from the ip because we can't compare
   164  		// it.
   165  		sourceCopy.SetTotalLength(uint16(len(ip)))
   166  		sourceCopy.SetChecksum(0)
   167  		sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
   168  		if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
   169  			t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
   170  		}
   171  	}
   172  	expected := source[source.HeaderLength():]
   173  	if !bytes.Equal(reassembledPayload, expected) {
   174  		t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
   175  	}
   176  }
   177  
   178  type errorChannel struct {
   179  	*channel.Endpoint
   180  	Ch                    chan packetInfo
   181  	packetCollectorErrors []*tcpip.Error
   182  }
   183  
   184  // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
   185  // will return successive errors from packetCollectorErrors until the list is
   186  // empty and then return nil each time.
   187  func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
   188  	return &errorChannel{
   189  		Endpoint:              channel.New(size, mtu, linkAddr),
   190  		Ch:                    make(chan packetInfo, size),
   191  		packetCollectorErrors: packetCollectorErrors,
   192  	}
   193  }
   194  
// packetInfo holds all the information about an outbound packet: the header
// and payload exactly as they were passed to errorChannel.WritePacket.
type packetInfo struct {
	// Header is the prependable header buffer; its used and available
	// lengths are both inspected by compareFragments.
	Header buffer.Prependable
	// Payload is the packet data that follows Header.
	Payload buffer.VectorisedView
}
   200  
   201  // Drain removes all outbound packets from the channel and counts them.
   202  func (e *errorChannel) Drain() int {
   203  	c := 0
   204  	for {
   205  		select {
   206  		case <-e.Ch:
   207  			c++
   208  		default:
   209  			return c
   210  		}
   211  	}
   212  }
   213  
   214  // WritePacket stores outbound packets into the channel.
   215  func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
   216  	p := packetInfo{
   217  		Header:  hdr,
   218  		Payload: payload,
   219  	}
   220  
   221  	select {
   222  	case e.Ch <- p:
   223  	default:
   224  	}
   225  
   226  	nextError := (*tcpip.Error)(nil)
   227  	if len(e.packetCollectorErrors) > 0 {
   228  		nextError = e.packetCollectorErrors[0]
   229  		e.packetCollectorErrors = e.packetCollectorErrors[1:]
   230  	}
   231  	return nextError
   232  }
   233  
// context pairs a route under test with the errorChannel link endpoint of the
// NIC that route uses, so tests can write through the route and then inspect
// what reached the link.
type context struct {
	stack.Route
	linkEP *errorChannel
}
   238  
   239  func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
   240  	// Make the packet and write it.
   241  	s := stack.New(stack.Options{
   242  		NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
   243  	})
   244  	ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
   245  	s.CreateNIC(1, ep)
   246  	const (
   247  		src = "\x10\x00\x00\x01"
   248  		dst = "\x10\x00\x00\x02"
   249  	)
   250  	s.AddAddress(1, ipv4.ProtocolNumber, src)
   251  	{
   252  		subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
   253  		if err != nil {
   254  			t.Fatal(err)
   255  		}
   256  		s.SetRouteTable([]tcpip.Route{{
   257  			Destination: subnet,
   258  			NIC:         1,
   259  		}})
   260  	}
   261  	r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
   262  	if err != nil {
   263  		t.Fatalf("s.FindRoute got %v, want %v", err, nil)
   264  	}
   265  	return context{
   266  		Route:  r,
   267  		linkEP: ep,
   268  	}
   269  }
   270  
   271  func TestFragmentation(t *testing.T) {
   272  	var manyPayloadViewsSizes [1000]int
   273  	for i := range manyPayloadViewsSizes {
   274  		manyPayloadViewsSizes[i] = 7
   275  	}
   276  	fragTests := []struct {
   277  		description       string
   278  		mtu               uint32
   279  		gso               *stack.GSO
   280  		hdrLength         int
   281  		extraLength       int
   282  		payloadViewsSizes []int
   283  		expectedFrags     int
   284  	}{
   285  		{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
   286  		{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
   287  		{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
   288  		{"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
   289  		{"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
   290  		{"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
   291  		{"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
   292  		{"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
   293  		{"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
   294  	}
   295  
   296  	for _, ft := range fragTests {
   297  		t.Run(ft.description, func(t *testing.T) {
   298  			hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
   299  			source := packetInfo{
   300  				Header: hdr,
   301  				// Save the source payload because WritePacket will modify it.
   302  				Payload: payload.Clone([]buffer.View{}),
   303  			}
   304  			c := buildContext(t, nil, ft.mtu)
   305  			err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
   306  			if err != nil {
   307  				t.Errorf("err got %v, want %v", err, nil)
   308  			}
   309  
   310  			var results []packetInfo
   311  		L:
   312  			for {
   313  				select {
   314  				case pi := <-c.linkEP.Ch:
   315  					results = append(results, pi)
   316  				default:
   317  					break L
   318  				}
   319  			}
   320  
   321  			if got, want := len(results), ft.expectedFrags; got != want {
   322  				t.Errorf("len(result) got %d, want %d", got, want)
   323  			}
   324  			if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
   325  				t.Errorf("no errors yet len(result) got %d, want %d", got, want)
   326  			}
   327  			compareFragments(t, results, source, ft.mtu)
   328  		})
   329  	}
   330  }
   331  
   332  // TestFragmentationErrors checks that errors are returned from write packet
   333  // correctly.
   334  func TestFragmentationErrors(t *testing.T) {
   335  	fragTests := []struct {
   336  		description           string
   337  		mtu                   uint32
   338  		hdrLength             int
   339  		payloadViewsSizes     []int
   340  		packetCollectorErrors []*tcpip.Error
   341  	}{
   342  		{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
   343  		{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
   344  		{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
   345  		{"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
   346  	}
   347  
   348  	for _, ft := range fragTests {
   349  		t.Run(ft.description, func(t *testing.T) {
   350  			hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
   351  			c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
   352  			err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
   353  			for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
   354  				if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
   355  					t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
   356  				}
   357  			}
   358  			// We only need to check that last error because all the ones before are
   359  			// nil.
   360  			if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
   361  				t.Errorf("err got %v, want %v", got, want)
   362  			}
   363  			if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
   364  				t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
   365  			}
   366  		})
   367  	}
   368  }