github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package fragmentation 16 17 import ( 18 "errors" 19 "testing" 20 "time" 21 22 "github.com/google/go-cmp/cmp" 23 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 24 "github.com/SagerNet/gvisor/pkg/tcpip/faketime" 25 "github.com/SagerNet/gvisor/pkg/tcpip/network/internal/testutil" 26 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 27 ) 28 29 // reassembleTimeout is dummy timeout used for testing, where the clock never 30 // advances. 31 const reassembleTimeout = 1 32 33 // vv is a helper to build VectorisedView from different strings. 34 func vv(size int, pieces ...string) buffer.VectorisedView { 35 views := make([]buffer.View, len(pieces)) 36 for i, p := range pieces { 37 views[i] = []byte(p) 38 } 39 40 return buffer.NewVectorisedView(size, views) 41 } 42 43 func pkt(size int, pieces ...string) *stack.PacketBuffer { 44 return stack.NewPacketBuffer(stack.PacketBufferOptions{ 45 Data: vv(size, pieces...), 46 }) 47 } 48 49 type processInput struct { 50 id FragmentID 51 first uint16 52 last uint16 53 more bool 54 proto uint8 55 pkt *stack.PacketBuffer 56 } 57 58 type processOutput struct { 59 vv buffer.VectorisedView 60 proto uint8 61 done bool 62 } 63 64 var processTestCases = []struct { 65 comment string 66 in []processInput 67 out []processOutput 68 }{ 69 { 70 comment: "One ID", 71 in: []processInput{ 72 {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, 73 {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, 74 }, 75 out: []processOutput{ 76 {vv: buffer.VectorisedView{}, done: false}, 77 {vv: vv(4, "01", "23"), done: true}, 78 }, 79 }, 80 { 81 comment: "Next Header protocol mismatch", 82 in: []processInput{ 83 {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, 84 {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, 85 }, 86 out: []processOutput{ 87 {vv: buffer.VectorisedView{}, done: false}, 88 {vv: vv(4, "01", "23"), proto: 6, done: true}, 89 }, 90 }, 91 { 92 comment: "Two IDs", 93 in: []processInput{ 94 {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, 95 {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, 96 {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, 97 {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, 98 }, 99 out: []processOutput{ 100 {vv: buffer.VectorisedView{}, done: false}, 101 {vv: buffer.VectorisedView{}, done: false}, 102 {vv: vv(4, "ab", "cd"), done: true}, 103 {vv: vv(4, "01", "23"), done: true}, 104 }, 105 }, 106 } 107 108 func TestFragmentationProcess(t *testing.T) { 109 for _, c := range processTestCases { 110 t.Run(c.comment, func(t *testing.T) { 111 f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) 112 firstFragmentProto := c.in[0].proto 113 for i, in := range c.in { 114 resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) 115 if err != nil { 116 t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", 117 in.id, in.first, in.last, in.more, in.proto, in.pkt, err) 118 } 119 if done != c.out[i].done { 120 t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", 121 in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) 122 } 123 if c.out[i].done { 124 if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { 125 t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", 126 in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) 127 } 128 if firstFragmentProto != proto { 129 t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", 130 in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) 131 } 132 if _, ok := f.reassemblers[in.id]; ok { 133 t.Errorf("Process(%d) did not remove buffer from reassemblers", i) 134 } 135 for n := f.rList.Front(); n != nil; n = n.Next() { 136 if n.id == in.id { 137 t.Errorf("Process(%d) did not remove buffer from rList", i) 138 } 139 } 140 } 141 } 142 }) 143 } 144 } 145 146 func TestReassemblingTimeout(t *testing.T) { 147 const ( 148 reassemblyTimeout = time.Millisecond 149 protocol = 0xff 150 ) 151 152 type fragment struct { 153 first uint16 154 last uint16 155 more bool 156 data string 157 } 158 159 type event struct { 160 // name is a nickname of this event. 161 name string 162 163 // clockAdvance is a duration to advance the clock. The clock advances 164 // before a fragment specified in the fragment field is processed. 165 clockAdvance time.Duration 166 167 // fragment is a fragment to process. This can be nil if there is no 168 // fragment to process. 169 fragment *fragment 170 171 // expectDone is true if the fragmentation instance should report the 172 // reassembly is done after the fragment is processd. 173 expectDone bool 174 175 // memSizeAfterEvent is the expected memory size of the fragmentation 176 // instance after the event. 177 memSizeAfterEvent int 178 } 179 180 memSizeOfFrags := func(frags ...*fragment) int { 181 var size int 182 for _, frag := range frags { 183 size += pkt(len(frag.data), frag.data).MemSize() 184 } 185 return size 186 } 187 188 half1 := &fragment{first: 0, last: 0, more: true, data: "0"} 189 half2 := &fragment{first: 1, last: 1, more: false, data: "1"} 190 191 tests := []struct { 192 name string 193 events []event 194 }{ 195 { 196 name: "half1 and half2 are reassembled successfully", 197 events: []event{ 198 { 199 name: "half1", 200 fragment: half1, 201 expectDone: false, 202 memSizeAfterEvent: memSizeOfFrags(half1), 203 }, 204 { 205 name: "half2", 206 fragment: half2, 207 expectDone: true, 208 memSizeAfterEvent: 0, 209 }, 210 }, 211 }, 212 { 213 name: "half1 timeout, half2 timeout", 214 events: []event{ 215 { 216 name: "half1", 217 fragment: half1, 218 expectDone: false, 219 memSizeAfterEvent: memSizeOfFrags(half1), 220 }, 221 { 222 name: "half1 just before reassembly timeout", 223 clockAdvance: reassemblyTimeout - 1, 224 memSizeAfterEvent: memSizeOfFrags(half1), 225 }, 226 { 227 name: "half1 reassembly timeout", 228 clockAdvance: 1, 229 memSizeAfterEvent: 0, 230 }, 231 { 232 name: "half2", 233 fragment: half2, 234 expectDone: false, 235 memSizeAfterEvent: memSizeOfFrags(half2), 236 }, 237 { 238 name: "half2 just before reassembly timeout", 239 clockAdvance: reassemblyTimeout - 1, 240 memSizeAfterEvent: memSizeOfFrags(half2), 241 }, 242 { 243 name: "half2 reassembly timeout", 244 clockAdvance: 1, 245 memSizeAfterEvent: 0, 246 }, 247 }, 248 }, 249 } 250 for _, test := range tests { 251 t.Run(test.name, func(t *testing.T) { 252 clock := faketime.NewManualClock() 253 f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) 254 for _, event := range test.events { 255 clock.Advance(event.clockAdvance) 256 if frag := event.fragment; frag != nil { 257 _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) 258 if err != nil { 259 t.Fatalf("%s: f.Process failed: %s", event.name, err) 260 } 261 if done != event.expectDone { 262 t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) 263 } 264 } 265 if got, want := f.memSize, event.memSizeAfterEvent; got != want { 266 t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) 267 } 268 } 269 }) 270 } 271 } 272 273 func TestMemoryLimits(t *testing.T) { 274 lowLimit := pkt(1, "0").MemSize() 275 highLimit := 3 * lowLimit // Allow at most 3 such packets. 276 f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) 277 // Send first fragment with id = 0. 278 if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { 279 t.Fatal(err) 280 } 281 // Send first fragment with id = 1. 282 if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { 283 t.Fatal(err) 284 } 285 // Send first fragment with id = 2. 286 if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { 287 t.Fatal(err) 288 } 289 290 // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be 291 // evicted. 292 if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { 293 t.Fatal(err) 294 } 295 296 if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { 297 t.Errorf("Memory limits are not respected: id=0 has not been evicted.") 298 } 299 if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { 300 t.Errorf("Memory limits are not respected: id=1 has not been evicted.") 301 } 302 if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { 303 t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") 304 } 305 } 306 307 func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { 308 memSize := pkt(1, "0").MemSize() 309 f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) 310 // Send first fragment with id = 0. 311 if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { 312 t.Fatal(err) 313 } 314 // Send the same packet again. 315 if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { 316 t.Fatal(err) 317 } 318 319 if got, want := f.memSize, memSize; got != want { 320 t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) 321 } 322 } 323 324 func TestErrors(t *testing.T) { 325 tests := []struct { 326 name string 327 blockSize uint16 328 first uint16 329 last uint16 330 more bool 331 data string 332 err error 333 }{ 334 { 335 name: "exact block size without more", 336 blockSize: 2, 337 first: 2, 338 last: 3, 339 more: false, 340 data: "01", 341 }, 342 { 343 name: "exact block size with more", 344 blockSize: 2, 345 first: 2, 346 last: 3, 347 more: true, 348 data: "01", 349 }, 350 { 351 name: "exact block size with more and extra data", 352 blockSize: 2, 353 first: 2, 354 last: 3, 355 more: true, 356 data: "012", 357 err: ErrInvalidArgs, 358 }, 359 { 360 name: "exact block size with more and too little data", 361 blockSize: 2, 362 first: 2, 363 last: 3, 364 more: true, 365 data: "0", 366 err: ErrInvalidArgs, 367 }, 368 { 369 name: "not exact block size with more", 370 blockSize: 2, 371 first: 2, 372 last: 2, 373 more: true, 374 data: "0", 375 err: ErrInvalidArgs, 376 }, 377 { 378 name: "not exact block size without more", 379 blockSize: 2, 380 first: 2, 381 last: 2, 382 more: false, 383 data: "0", 384 }, 385 { 386 name: "first not a multiple of block size", 387 blockSize: 2, 388 first: 3, 389 last: 4, 390 more: true, 391 data: "01", 392 err: ErrInvalidArgs, 393 }, 394 { 395 name: "first more than last", 396 blockSize: 2, 397 first: 4, 398 last: 3, 399 more: true, 400 data: "01", 401 err: ErrInvalidArgs, 402 }, 403 } 404 405 for _, test := range tests { 406 t.Run(test.name, func(t *testing.T) { 407 f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) 408 _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) 409 if !errors.Is(err, test.err) { 410 t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) 411 } 412 if done { 413 t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) 414 } 415 }) 416 } 417 } 418 419 type fragmentInfo struct { 420 remaining int 421 copied int 422 offset int 423 more bool 424 } 425 426 func TestPacketFragmenter(t *testing.T) { 427 const ( 428 reserve = 60 429 proto = 0 430 ) 431 432 tests := []struct { 433 name string 434 fragmentPayloadLen uint32 435 transportHeaderLen int 436 payloadSize int 437 wantFragments []fragmentInfo 438 }{ 439 { 440 name: "Packet exactly fits in MTU", 441 fragmentPayloadLen: 1280, 442 transportHeaderLen: 0, 443 payloadSize: 1280, 444 wantFragments: []fragmentInfo{ 445 {remaining: 0, copied: 1280, offset: 0, more: false}, 446 }, 447 }, 448 { 449 name: "Packet exactly does not fit in MTU", 450 fragmentPayloadLen: 1000, 451 transportHeaderLen: 0, 452 payloadSize: 1001, 453 wantFragments: []fragmentInfo{ 454 {remaining: 1, copied: 1000, offset: 0, more: true}, 455 {remaining: 0, copied: 1, offset: 1000, more: false}, 456 }, 457 }, 458 { 459 name: "Packet has a transport header", 460 fragmentPayloadLen: 560, 461 transportHeaderLen: 40, 462 payloadSize: 560, 463 wantFragments: []fragmentInfo{ 464 {remaining: 1, copied: 560, offset: 0, more: true}, 465 {remaining: 0, copied: 40, offset: 560, more: false}, 466 }, 467 }, 468 { 469 name: "Packet has a huge transport header", 470 fragmentPayloadLen: 500, 471 transportHeaderLen: 1300, 472 payloadSize: 500, 473 wantFragments: []fragmentInfo{ 474 {remaining: 3, copied: 500, offset: 0, more: true}, 475 {remaining: 2, copied: 500, offset: 500, more: true}, 476 {remaining: 1, copied: 500, offset: 1000, more: true}, 477 {remaining: 0, copied: 300, offset: 1500, more: false}, 478 }, 479 }, 480 } 481 482 for _, test := range tests { 483 t.Run(test.name, func(t *testing.T) { 484 pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) 485 originalPayload := stack.PayloadSince(pkt.TransportHeader()) 486 var reassembledPayload buffer.VectorisedView 487 pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) 488 for i := 0; ; i++ { 489 fragPkt, offset, copied, more := pf.BuildNextFragment() 490 wantFragment := test.wantFragments[i] 491 if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { 492 t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) 493 } 494 if copied != wantFragment.copied { 495 t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) 496 } 497 if offset != wantFragment.offset { 498 t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) 499 } 500 if more != wantFragment.more { 501 t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) 502 } 503 if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { 504 t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) 505 } 506 if got := fragPkt.AvailableHeaderBytes(); got != reserve { 507 t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) 508 } 509 if got := fragPkt.TransportHeader().View().Size(); got != 0 { 510 t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) 511 } 512 reassembledPayload.AppendViews(fragPkt.Data().Views()) 513 if !more { 514 if i != len(test.wantFragments)-1 { 515 t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) 516 } 517 break 518 } 519 } 520 if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { 521 t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) 522 } 523 }) 524 } 525 } 526 527 type testTimeoutHandler struct { 528 pkt *stack.PacketBuffer 529 } 530 531 func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { 532 h.pkt = pkt 533 } 534 535 func TestTimeoutHandler(t *testing.T) { 536 const ( 537 proto = 99 538 ) 539 540 pk1 := pkt(1, "1") 541 pk2 := pkt(1, "2") 542 543 type processParam struct { 544 first uint16 545 last uint16 546 more bool 547 pkt *stack.PacketBuffer 548 } 549 550 tests := []struct { 551 name string 552 params []processParam 553 wantError bool 554 wantPkt *stack.PacketBuffer 555 }{ 556 { 557 name: "onTimeout runs", 558 params: []processParam{ 559 { 560 first: 0, 561 last: 0, 562 more: true, 563 pkt: pk1, 564 }, 565 }, 566 wantError: false, 567 wantPkt: pk1, 568 }, 569 { 570 name: "no first fragment", 571 params: []processParam{ 572 { 573 first: 1, 574 last: 1, 575 more: true, 576 pkt: pk1, 577 }, 578 }, 579 wantError: false, 580 wantPkt: nil, 581 }, 582 { 583 name: "second pkt is ignored", 584 params: []processParam{ 585 { 586 first: 0, 587 last: 0, 588 more: true, 589 pkt: pk1, 590 }, 591 { 592 first: 0, 593 last: 0, 594 more: true, 595 pkt: pk2, 596 }, 597 }, 598 wantError: false, 599 wantPkt: pk1, 600 }, 601 { 602 name: "invalid args - first is greater than last", 603 params: []processParam{ 604 { 605 first: 1, 606 last: 0, 607 more: true, 608 pkt: pk1, 609 }, 610 }, 611 wantError: true, 612 wantPkt: nil, 613 }, 614 } 615 616 id := FragmentID{ID: 0} 617 618 for _, test := range tests { 619 t.Run(test.name, func(t *testing.T) { 620 handler := &testTimeoutHandler{pkt: nil} 621 622 f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) 623 624 for _, p := range test.params { 625 if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { 626 t.Errorf("f.Process error = %s", err) 627 } 628 } 629 if !test.wantError { 630 r, ok := f.reassemblers[id] 631 if !ok { 632 t.Fatal("Reassembler not found") 633 } 634 f.release(r, true) 635 } 636 switch { 637 case handler.pkt != nil && test.wantPkt == nil: 638 t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) 639 case handler.pkt == nil && test.wantPkt != nil: 640 t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) 641 case handler.pkt != nil && test.wantPkt != nil: 642 if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { 643 t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) 644 } 645 } 646 }) 647 } 648 }