github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/stack/packet_buffer_test.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package stack 15 16 import ( 17 "bytes" 18 "fmt" 19 "testing" 20 21 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 22 "github.com/SagerNet/gvisor/pkg/tcpip/header" 23 ) 24 25 func TestPacketHeaderPush(t *testing.T) { 26 for _, test := range []struct { 27 name string 28 reserved int 29 link []byte 30 network []byte 31 transport []byte 32 data []byte 33 }{ 34 { 35 name: "construct empty packet", 36 }, 37 { 38 name: "construct link header only packet", 39 reserved: 60, 40 link: makeView(10), 41 }, 42 { 43 name: "construct link and network header only packet", 44 reserved: 60, 45 link: makeView(10), 46 network: makeView(20), 47 }, 48 { 49 name: "construct header only packet", 50 reserved: 60, 51 link: makeView(10), 52 network: makeView(20), 53 transport: makeView(30), 54 }, 55 { 56 name: "construct data only packet", 57 data: makeView(40), 58 }, 59 { 60 name: "construct L3 packet", 61 reserved: 60, 62 network: makeView(20), 63 transport: makeView(30), 64 data: makeView(40), 65 }, 66 { 67 name: "construct L2 packet", 68 reserved: 60, 69 link: makeView(10), 70 network: makeView(20), 71 transport: makeView(30), 72 data: makeView(40), 73 }, 74 } { 75 t.Run(test.name, func(t *testing.T) { 76 pk := NewPacketBuffer(PacketBufferOptions{ 77 ReserveHeaderBytes: test.reserved, 78 // Make a copy of data to make sure our truth data won't be taint by 79 // PacketBuffer. 80 Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), 81 }) 82 83 allHdrSize := len(test.link) + len(test.network) + len(test.transport) 84 85 // Check the initial values for packet. 86 checkInitialPacketBuffer(t, pk, PacketBufferOptions{ 87 ReserveHeaderBytes: test.reserved, 88 Data: buffer.View(test.data).ToVectorisedView(), 89 }) 90 91 // Push headers. 92 if v := test.transport; len(v) > 0 { 93 copy(pk.TransportHeader().Push(len(v)), v) 94 } 95 if v := test.network; len(v) > 0 { 96 copy(pk.NetworkHeader().Push(len(v)), v) 97 } 98 if v := test.link; len(v) > 0 { 99 copy(pk.LinkHeader().Push(len(v)), v) 100 } 101 102 // Check the after values for packet. 103 if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { 104 t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) 105 } 106 if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { 107 t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) 108 } 109 if got, want := pk.HeaderSize(), allHdrSize; got != want { 110 t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) 111 } 112 if got, want := pk.Size(), allHdrSize+len(test.data); got != want { 113 t.Errorf("After pk.Size() = %d, want %d", got, want) 114 } 115 // Check the after state. 116 checkPacketContents(t, "After ", pk, packetContents{ 117 link: test.link, 118 network: test.network, 119 transport: test.transport, 120 data: test.data, 121 }) 122 }) 123 } 124 } 125 126 func TestPacketHeaderConsume(t *testing.T) { 127 for _, test := range []struct { 128 name string 129 data []byte 130 link int 131 network int 132 transport int 133 }{ 134 { 135 name: "parse L2 packet", 136 data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), 137 link: 10, 138 network: 20, 139 transport: 30, 140 }, 141 { 142 name: "parse L3 packet", 143 data: concatViews(makeView(20), makeView(30), makeView(40)), 144 network: 20, 145 transport: 30, 146 }, 147 } { 148 t.Run(test.name, func(t *testing.T) { 149 pk := NewPacketBuffer(PacketBufferOptions{ 150 // Make a copy of data to make sure our truth data won't be taint by 151 // PacketBuffer. 152 Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), 153 }) 154 155 // Check the initial values for packet. 156 checkInitialPacketBuffer(t, pk, PacketBufferOptions{ 157 Data: buffer.View(test.data).ToVectorisedView(), 158 }) 159 160 // Consume headers. 161 if size := test.link; size > 0 { 162 if _, ok := pk.LinkHeader().Consume(size); !ok { 163 t.Fatalf("pk.LinkHeader().Consume() = false, want true") 164 } 165 } 166 if size := test.network; size > 0 { 167 if _, ok := pk.NetworkHeader().Consume(size); !ok { 168 t.Fatalf("pk.NetworkHeader().Consume() = false, want true") 169 } 170 } 171 if size := test.transport; size > 0 { 172 if _, ok := pk.TransportHeader().Consume(size); !ok { 173 t.Fatalf("pk.TransportHeader().Consume() = false, want true") 174 } 175 } 176 177 allHdrSize := test.link + test.network + test.transport 178 179 // Check the after values for packet. 180 if got, want := pk.ReservedHeaderBytes(), 0; got != want { 181 t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) 182 } 183 if got, want := pk.AvailableHeaderBytes(), 0; got != want { 184 t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) 185 } 186 if got, want := pk.HeaderSize(), allHdrSize; got != want { 187 t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) 188 } 189 if got, want := pk.Size(), len(test.data); got != want { 190 t.Errorf("After pk.Size() = %d, want %d", got, want) 191 } 192 // Check the after state of pk. 193 checkPacketContents(t, "After ", pk, packetContents{ 194 link: test.data[:test.link], 195 network: test.data[test.link:][:test.network], 196 transport: test.data[test.link+test.network:][:test.transport], 197 data: test.data[allHdrSize:], 198 }) 199 }) 200 } 201 } 202 203 func TestPacketHeaderConsumeDataTooShort(t *testing.T) { 204 data := makeView(10) 205 206 pk := NewPacketBuffer(PacketBufferOptions{ 207 // Make a copy of data to make sure our truth data won't be taint by 208 // PacketBuffer. 209 Data: buffer.NewViewFromBytes(data).ToVectorisedView(), 210 }) 211 212 // Consume should fail if pkt.Data is too short. 213 if _, ok := pk.LinkHeader().Consume(11); ok { 214 t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") 215 } 216 if _, ok := pk.NetworkHeader().Consume(11); ok { 217 t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") 218 } 219 if _, ok := pk.TransportHeader().Consume(11); ok { 220 t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") 221 } 222 223 // Check packet should look the same as initial packet. 224 checkInitialPacketBuffer(t, pk, PacketBufferOptions{ 225 Data: buffer.View(data).ToVectorisedView(), 226 }) 227 } 228 229 // This is a very obscure use-case seen in the code that verifies packets 230 // before sending them out. It tries to parse the headers to verify. 231 // PacketHeader was initially not designed to mix Push() and Consume(), but it 232 // works and it's been relied upon. Include a test here. 233 func TestPacketHeaderPushConsumeMixed(t *testing.T) { 234 link := makeView(10) 235 network := makeView(20) 236 data := makeView(30) 237 238 initData := append([]byte(nil), network...) 239 initData = append(initData, data...) 240 pk := NewPacketBuffer(PacketBufferOptions{ 241 ReserveHeaderBytes: len(link), 242 Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), 243 }) 244 245 // 1. Consume network header 246 gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) 247 if !ok { 248 t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) 249 } 250 checkViewEqual(t, "gotNetwork", gotNetwork, network) 251 252 // 2. Push link header 253 copy(pk.LinkHeader().Push(len(link)), link) 254 255 checkPacketContents(t, "" /* prefix */, pk, packetContents{ 256 link: link, 257 network: network, 258 data: data, 259 }) 260 } 261 262 func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { 263 link := makeView(10) 264 network := makeView(20) 265 data := makeView(30) 266 267 initData := concatViews(network, data) 268 pk := NewPacketBuffer(PacketBufferOptions{ 269 ReserveHeaderBytes: len(link), 270 Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), 271 }) 272 273 // 1. Push link header 274 copy(pk.LinkHeader().Push(len(link)), link) 275 276 checkPacketContents(t, "" /* prefix */, pk, packetContents{ 277 link: link, 278 data: initData, 279 }) 280 281 // 2. Consume network header, with a number of bytes too large. 282 gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) 283 if ok { 284 t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) 285 } 286 287 checkPacketContents(t, "" /* prefix */, pk, packetContents{ 288 link: link, 289 data: initData, 290 }) 291 } 292 293 func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { 294 const headerSize = 10 295 296 pk := NewPacketBuffer(PacketBufferOptions{ 297 ReserveHeaderBytes: headerSize * int(numHeaderType), 298 }) 299 300 for _, h := range []PacketHeader{ 301 pk.TransportHeader(), 302 pk.NetworkHeader(), 303 pk.LinkHeader(), 304 } { 305 t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { 306 h.Push(headerSize) 307 308 defer func() { recover() }() 309 h.Push(headerSize) 310 t.Fatal("Second push should have panicked") 311 }) 312 } 313 } 314 315 func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { 316 const headerSize = 10 317 318 pk := NewPacketBuffer(PacketBufferOptions{ 319 Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), 320 }) 321 322 for _, h := range []PacketHeader{ 323 pk.LinkHeader(), 324 pk.NetworkHeader(), 325 pk.TransportHeader(), 326 } { 327 t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { 328 if _, ok := h.Consume(headerSize); !ok { 329 t.Fatal("First consume should succeed") 330 } 331 332 defer func() { recover() }() 333 h.Consume(headerSize) 334 t.Fatal("Second consume should have panicked") 335 }) 336 } 337 } 338 339 func TestPacketHeaderPushThenConsumePanics(t *testing.T) { 340 const headerSize = 10 341 342 pk := NewPacketBuffer(PacketBufferOptions{ 343 ReserveHeaderBytes: headerSize * int(numHeaderType), 344 }) 345 346 for _, h := range []PacketHeader{ 347 pk.TransportHeader(), 348 pk.NetworkHeader(), 349 pk.LinkHeader(), 350 } { 351 t.Run(h.typ.String(), func(t *testing.T) { 352 h.Push(headerSize) 353 354 defer func() { recover() }() 355 h.Consume(headerSize) 356 t.Fatal("Consume should have panicked") 357 }) 358 } 359 } 360 361 func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { 362 const headerSize = 10 363 364 pk := NewPacketBuffer(PacketBufferOptions{ 365 Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), 366 }) 367 368 for _, h := range []PacketHeader{ 369 pk.LinkHeader(), 370 pk.NetworkHeader(), 371 pk.TransportHeader(), 372 } { 373 t.Run(h.typ.String(), func(t *testing.T) { 374 h.Consume(headerSize) 375 376 defer func() { recover() }() 377 h.Push(headerSize) 378 t.Fatal("Push should have panicked") 379 }) 380 } 381 } 382 383 func TestPacketBufferData(t *testing.T) { 384 for _, tc := range []struct { 385 name string 386 makePkt func(*testing.T) *PacketBuffer 387 data string 388 }{ 389 { 390 name: "inbound packet", 391 makePkt: func(*testing.T) *PacketBuffer { 392 pkt := NewPacketBuffer(PacketBufferOptions{ 393 Data: vv("aabbbbccccccDATA"), 394 }) 395 pkt.LinkHeader().Consume(2) 396 pkt.NetworkHeader().Consume(4) 397 pkt.TransportHeader().Consume(6) 398 return pkt 399 }, 400 data: "DATA", 401 }, 402 { 403 name: "outbound packet", 404 makePkt: func(*testing.T) *PacketBuffer { 405 pkt := NewPacketBuffer(PacketBufferOptions{ 406 ReserveHeaderBytes: 12, 407 Data: vv("DATA"), 408 }) 409 copy(pkt.TransportHeader().Push(6), []byte("cccccc")) 410 copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) 411 copy(pkt.LinkHeader().Push(2), []byte("aa")) 412 return pkt 413 }, 414 data: "DATA", 415 }, 416 } { 417 t.Run(tc.name, func(t *testing.T) { 418 // PullUp 419 for _, n := range []int{1, len(tc.data)} { 420 t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { 421 pkt := tc.makePkt(t) 422 v, ok := pkt.Data().PullUp(n) 423 wantV := []byte(tc.data)[:n] 424 if !ok || !bytes.Equal(v, wantV) { 425 t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) 426 } 427 }) 428 } 429 t.Run("PullUpOutOfBounds", func(t *testing.T) { 430 n := len(tc.data) + 1 431 pkt := tc.makePkt(t) 432 v, ok := pkt.Data().PullUp(n) 433 if ok || v != nil { 434 t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) 435 } 436 }) 437 438 // DeleteFront 439 for _, n := range []int{1, len(tc.data)} { 440 t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { 441 pkt := tc.makePkt(t) 442 pkt.Data().DeleteFront(n) 443 444 checkData(t, pkt, []byte(tc.data)[n:]) 445 }) 446 } 447 448 // CapLength 449 for _, n := range []int{0, 1, len(tc.data)} { 450 t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { 451 pkt := tc.makePkt(t) 452 pkt.Data().CapLength(n) 453 454 want := []byte(tc.data) 455 if n < len(want) { 456 want = want[:n] 457 } 458 checkData(t, pkt, want) 459 }) 460 } 461 462 // Views 463 t.Run("Views", func(t *testing.T) { 464 pkt := tc.makePkt(t) 465 checkData(t, pkt, []byte(tc.data)) 466 }) 467 468 // AppendView 469 t.Run("AppendView", func(t *testing.T) { 470 s := "APPEND" 471 472 pkt := tc.makePkt(t) 473 pkt.Data().AppendView(buffer.View(s)) 474 475 checkData(t, pkt, []byte(tc.data+s)) 476 }) 477 478 // ReadFromVV 479 for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { 480 t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { 481 s := "TO READ" 482 srcVV := vv(s, s) 483 s += s 484 485 pkt := tc.makePkt(t) 486 pkt.Data().ReadFromVV(&srcVV, n) 487 488 if n < len(s) { 489 s = s[:n] 490 } 491 checkData(t, pkt, []byte(tc.data+s)) 492 }) 493 } 494 495 // ExtractVV 496 t.Run("ExtractVV", func(t *testing.T) { 497 pkt := tc.makePkt(t) 498 extractedVV := pkt.Data().ExtractVV() 499 500 got := extractedVV.ToOwnedView() 501 want := []byte(tc.data) 502 if !bytes.Equal(got, want) { 503 t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) 504 } 505 }) 506 }) 507 } 508 } 509 510 type packetContents struct { 511 link buffer.View 512 network buffer.View 513 transport buffer.View 514 data buffer.View 515 } 516 517 func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { 518 t.Helper() 519 // Headers. 520 checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) 521 checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) 522 checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) 523 // Data. 524 checkData(t, pk, want.data) 525 // Whole packet. 526 checkViewEqual(t, prefix+"pk.Views()", 527 concatViews(pk.Views()...), 528 concatViews(want.link, want.network, want.transport, want.data)) 529 // PayloadSince. 530 checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", 531 PayloadSince(pk.LinkHeader()), 532 concatViews(want.link, want.network, want.transport, want.data)) 533 checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", 534 PayloadSince(pk.NetworkHeader()), 535 concatViews(want.network, want.transport, want.data)) 536 checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", 537 PayloadSince(pk.TransportHeader()), 538 concatViews(want.transport, want.data)) 539 } 540 541 func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { 542 t.Helper() 543 reserved := opts.ReserveHeaderBytes 544 if got, want := pk.ReservedHeaderBytes(), reserved; got != want { 545 t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) 546 } 547 if got, want := pk.AvailableHeaderBytes(), reserved; got != want { 548 t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) 549 } 550 if got, want := pk.HeaderSize(), 0; got != want { 551 t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) 552 } 553 data := opts.Data.ToView() 554 if got, want := pk.Size(), len(data); got != want { 555 t.Errorf("Initial pk.Size() = %d, want %d", got, want) 556 } 557 checkPacketContents(t, "Initial ", pk, packetContents{ 558 data: data, 559 }) 560 } 561 562 func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { 563 t.Helper() 564 checkViewEqual(t, name+".View()", h.View(), want) 565 } 566 567 func checkViewEqual(t *testing.T, what string, got, want buffer.View) { 568 t.Helper() 569 if !bytes.Equal(got, want) { 570 t.Errorf("%s = %x, want %x", what, got, want) 571 } 572 } 573 574 func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { 575 t.Helper() 576 if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { 577 t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) 578 } 579 if got := pkt.Data().Size(); got != len(want) { 580 t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) 581 } 582 583 t.Run("AsRange", func(t *testing.T) { 584 // Full range 585 checkRange(t, pkt.Data().AsRange(), want) 586 587 // SubRange 588 for _, off := range []int{0, 1, len(want), len(want) + 1} { 589 t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { 590 // Empty when off is greater than the size of range. 591 var sub []byte 592 if off < len(want) { 593 sub = want[off:] 594 } 595 checkRange(t, pkt.Data().AsRange().SubRange(off), sub) 596 }) 597 } 598 599 // Capped 600 for _, n := range []int{0, 1, len(want), len(want) + 1} { 601 t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { 602 sub := want 603 if n < len(sub) { 604 sub = sub[:n] 605 } 606 checkRange(t, pkt.Data().AsRange().Capped(n), sub) 607 }) 608 } 609 }) 610 } 611 612 func checkRange(t *testing.T, r Range, data []byte) { 613 if got, want := r.Size(), len(data); got != want { 614 t.Errorf("r.Size() = %d, want %d", got, want) 615 } 616 if got := r.AsView(); !bytes.Equal(got, data) { 617 t.Errorf("r.AsView() = %x, want %x", got, data) 618 } 619 if got := r.ToOwnedView(); !bytes.Equal(got, data) { 620 t.Errorf("r.ToOwnedView() = %x, want %x", got, data) 621 } 622 if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { 623 t.Errorf("r.Checksum() = %x, want %x", got, want) 624 } 625 } 626 627 func vv(pieces ...string) buffer.VectorisedView { 628 var views []buffer.View 629 var size int 630 for _, p := range pieces { 631 v := buffer.View([]byte(p)) 632 size += len(v) 633 views = append(views, v) 634 } 635 return buffer.NewVectorisedView(size, views) 636 } 637 638 func makeView(size int) buffer.View { 639 b := byte(size) 640 return bytes.Repeat([]byte{b}, size) 641 } 642 643 func concatViews(views ...buffer.View) buffer.View { 644 var all buffer.View 645 for _, v := range views { 646 all = append(all, v...) 647 } 648 return all 649 }