github.com/flowerwrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/network/ipv4/ipv4_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ipv4_test 16 17 import ( 18 "bytes" 19 "encoding/hex" 20 "math/rand" 21 "testing" 22 23 "github.com/FlowerWrong/netstack/tcpip" 24 "github.com/FlowerWrong/netstack/tcpip/buffer" 25 "github.com/FlowerWrong/netstack/tcpip/header" 26 "github.com/FlowerWrong/netstack/tcpip/link/channel" 27 "github.com/FlowerWrong/netstack/tcpip/link/sniffer" 28 "github.com/FlowerWrong/netstack/tcpip/network/ipv4" 29 "github.com/FlowerWrong/netstack/tcpip/stack" 30 "github.com/FlowerWrong/netstack/tcpip/transport/tcp" 31 "github.com/FlowerWrong/netstack/tcpip/transport/udp" 32 "github.com/FlowerWrong/netstack/waiter" 33 ) 34 35 func TestExcludeBroadcast(t *testing.T) { 36 s := stack.New(stack.Options{ 37 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, 38 TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, 39 }) 40 41 const defaultMTU = 65536 42 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) 43 if testing.Verbose() { 44 ep = sniffer.New(ep) 45 } 46 if err := s.CreateNIC(1, ep); err != nil { 47 t.Fatalf("CreateNIC failed: %v", err) 48 } 49 50 if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { 51 t.Fatalf("AddAddress failed: %v", err) 52 } 53 54 s.SetRouteTable([]tcpip.Route{{ 55 Destination: header.IPv4EmptySubnet, 56 NIC: 1, 57 }}) 58 59 randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} 60 61 var wq waiter.Queue 62 t.Run("WithoutPrimaryAddress", func(t *testing.T) { 63 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 64 if err != nil { 65 t.Fatal(err) 66 } 67 defer ep.Close() 68 69 // Cannot connect using a broadcast address as the source. 70 if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { 71 t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) 72 } 73 74 // However, we can bind to a broadcast address to listen. 75 if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { 76 t.Errorf("Bind failed: %v", err) 77 } 78 }) 79 80 t.Run("WithPrimaryAddress", func(t *testing.T) { 81 ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) 82 if err != nil { 83 t.Fatal(err) 84 } 85 defer ep.Close() 86 87 // Add a valid primary endpoint address, now we can connect. 88 if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { 89 t.Fatalf("AddAddress failed: %v", err) 90 } 91 if err := ep.Connect(randomAddr); err != nil { 92 t.Errorf("Connect failed: %v", err) 93 } 94 }) 95 } 96 97 // makeHdrAndPayload generates a randomize packet. hdrLength indicates how much 98 // data should already be in the header before WritePacket. extraLength 99 // indicates how much extra space should be in the header. The payload is made 100 // from many Views of the sizes listed in viewSizes. 101 func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { 102 hdr := buffer.NewPrependable(hdrLength + extraLength) 103 hdr.Prepend(hdrLength) 104 rand.Read(hdr.View()) 105 106 var views []buffer.View 107 totalLength := 0 108 for _, s := range viewSizes { 109 newView := buffer.NewView(s) 110 rand.Read(newView) 111 views = append(views, newView) 112 totalLength += s 113 } 114 payload := buffer.NewVectorisedView(totalLength, views) 115 return hdr, payload 116 } 117 118 // comparePayloads compared the contents of all the packets against the contents 119 // of the source packet. 120 func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { 121 t.Helper() 122 // Make a complete array of the sourcePacketInfo packet. 123 source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) 124 source = append(source, sourcePacketInfo.Header.View()...) 125 source = append(source, sourcePacketInfo.Payload.ToView()...) 126 127 // Make a copy of the IP header, which will be modified in some fields to make 128 // an expected header. 129 sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) 130 sourceCopy.SetChecksum(0) 131 sourceCopy.SetFlagsFragmentOffset(0, 0) 132 sourceCopy.SetTotalLength(0) 133 var offset uint16 134 // Build up an array of the bytes sent. 135 var reassembledPayload []byte 136 for i, packet := range packets { 137 // Confirm that the packet is valid. 138 allBytes := packet.Header.View().ToVectorisedView() 139 allBytes.Append(packet.Payload) 140 ip := header.IPv4(allBytes.ToView()) 141 if !ip.IsValid(len(ip)) { 142 t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) 143 } 144 if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { 145 t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) 146 } 147 if got, want := len(ip), int(mtu); got > want { 148 t.Errorf("fragment is too large, got %d want %d", got, want) 149 } 150 if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { 151 t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) 152 } 153 if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { 154 t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) 155 } 156 if i < len(packets)-1 { 157 sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) 158 } else { 159 sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) 160 } 161 reassembledPayload = append(reassembledPayload, ip.Payload()...) 162 offset += ip.TotalLength() - uint16(ip.HeaderLength()) 163 // Clear out the checksum and length from the ip because we can't compare 164 // it. 165 sourceCopy.SetTotalLength(uint16(len(ip))) 166 sourceCopy.SetChecksum(0) 167 sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) 168 if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { 169 t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) 170 } 171 } 172 expected := source[source.HeaderLength():] 173 if !bytes.Equal(reassembledPayload, expected) { 174 t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) 175 } 176 } 177 178 type errorChannel struct { 179 *channel.Endpoint 180 Ch chan packetInfo 181 packetCollectorErrors []*tcpip.Error 182 } 183 184 // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket 185 // will return successive errors from packetCollectorErrors until the list is 186 // empty and then return nil each time. 187 func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { 188 return &errorChannel{ 189 Endpoint: channel.New(size, mtu, linkAddr), 190 Ch: make(chan packetInfo, size), 191 packetCollectorErrors: packetCollectorErrors, 192 } 193 } 194 195 // packetInfo holds all the information about an outbound packet. 196 type packetInfo struct { 197 Header buffer.Prependable 198 Payload buffer.VectorisedView 199 } 200 201 // Drain removes all outbound packets from the channel and counts them. 202 func (e *errorChannel) Drain() int { 203 c := 0 204 for { 205 select { 206 case <-e.Ch: 207 c++ 208 default: 209 return c 210 } 211 } 212 } 213 214 // WritePacket stores outbound packets into the channel. 215 func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { 216 p := packetInfo{ 217 Header: hdr, 218 Payload: payload, 219 } 220 221 select { 222 case e.Ch <- p: 223 default: 224 } 225 226 nextError := (*tcpip.Error)(nil) 227 if len(e.packetCollectorErrors) > 0 { 228 nextError = e.packetCollectorErrors[0] 229 e.packetCollectorErrors = e.packetCollectorErrors[1:] 230 } 231 return nextError 232 } 233 234 type context struct { 235 stack.Route 236 linkEP *errorChannel 237 } 238 239 func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { 240 // Make the packet and write it. 241 s := stack.New(stack.Options{ 242 NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, 243 }) 244 ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) 245 s.CreateNIC(1, ep) 246 const ( 247 src = "\x10\x00\x00\x01" 248 dst = "\x10\x00\x00\x02" 249 ) 250 s.AddAddress(1, ipv4.ProtocolNumber, src) 251 { 252 subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) 253 if err != nil { 254 t.Fatal(err) 255 } 256 s.SetRouteTable([]tcpip.Route{{ 257 Destination: subnet, 258 NIC: 1, 259 }}) 260 } 261 r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) 262 if err != nil { 263 t.Fatalf("s.FindRoute got %v, want %v", err, nil) 264 } 265 return context{ 266 Route: r, 267 linkEP: ep, 268 } 269 } 270 271 func TestFragmentation(t *testing.T) { 272 var manyPayloadViewsSizes [1000]int 273 for i := range manyPayloadViewsSizes { 274 manyPayloadViewsSizes[i] = 7 275 } 276 fragTests := []struct { 277 description string 278 mtu uint32 279 gso *stack.GSO 280 hdrLength int 281 extraLength int 282 payloadViewsSizes []int 283 expectedFrags int 284 }{ 285 {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, 286 {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, 287 {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, 288 {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, 289 {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, 290 {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, 291 {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, 292 {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, 293 {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, 294 } 295 296 for _, ft := range fragTests { 297 t.Run(ft.description, func(t *testing.T) { 298 hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) 299 source := packetInfo{ 300 Header: hdr, 301 // Save the source payload because WritePacket will modify it. 302 Payload: payload.Clone([]buffer.View{}), 303 } 304 c := buildContext(t, nil, ft.mtu) 305 err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */) 306 if err != nil { 307 t.Errorf("err got %v, want %v", err, nil) 308 } 309 310 var results []packetInfo 311 L: 312 for { 313 select { 314 case pi := <-c.linkEP.Ch: 315 results = append(results, pi) 316 default: 317 break L 318 } 319 } 320 321 if got, want := len(results), ft.expectedFrags; got != want { 322 t.Errorf("len(result) got %d, want %d", got, want) 323 } 324 if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { 325 t.Errorf("no errors yet len(result) got %d, want %d", got, want) 326 } 327 compareFragments(t, results, source, ft.mtu) 328 }) 329 } 330 } 331 332 // TestFragmentationErrors checks that errors are returned from write packet 333 // correctly. 334 func TestFragmentationErrors(t *testing.T) { 335 fragTests := []struct { 336 description string 337 mtu uint32 338 hdrLength int 339 payloadViewsSizes []int 340 packetCollectorErrors []*tcpip.Error 341 }{ 342 {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, 343 {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, 344 {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, 345 {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, 346 } 347 348 for _, ft := range fragTests { 349 t.Run(ft.description, func(t *testing.T) { 350 hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) 351 c := buildContext(t, ft.packetCollectorErrors, ft.mtu) 352 err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */) 353 for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { 354 if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { 355 t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) 356 } 357 } 358 // We only need to check that last error because all the ones before are 359 // nil. 360 if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { 361 t.Errorf("err got %v, want %v", got, want) 362 } 363 if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { 364 t.Errorf("after linkEP error len(result) got %d, want %d", got, want) 365 } 366 }) 367 } 368 }