github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/internalutils/stream/stream_test.go (about) 1 /* 2 Copyright 2022 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package stream 18 19 import ( 20 "errors" 21 "fmt" 22 "io" 23 "strconv" 24 "testing" 25 "time" 26 27 "github.com/stretchr/testify/require" 28 ) 29 30 // TestSlice tests the slice stream. 31 func TestSlice(t *testing.T) { 32 t.Parallel() 33 34 // normal usage 35 s, err := Collect(Slice([]int{1, 2, 3})) 36 require.NoError(t, err) 37 require.Equal(t, []int{1, 2, 3}, s) 38 39 // single-element slice 40 s, err = Collect(Slice([]int{100})) 41 require.NoError(t, err) 42 require.Equal(t, []int{100}, s) 43 44 // nil slice 45 s, err = Collect(Slice[int](nil)) 46 require.NoError(t, err) 47 require.Empty(t, s) 48 } 49 50 // TestFilterMap tests the FilterMap combinator. 51 func TestFilterMap(t *testing.T) { 52 t.Parallel() 53 54 // normal usage 55 s, err := Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) { 56 if i%2 == 0 { 57 return fmt.Sprintf("%d", i*10), true 58 } 59 return "", false 60 })) 61 require.NoError(t, err) 62 require.Equal(t, []string{"20", "40"}, s) 63 64 // single-match 65 s, err = Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) { 66 if i == 3 { 67 return "three", true 68 } 69 return "", false 70 })) 71 require.NoError(t, err) 72 require.Equal(t, []string{"three"}, s) 73 74 // no matches 75 s, err = Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) { 76 return "", false 77 })) 78 require.NoError(t, err) 79 require.Empty(t, s) 80 81 // empty stream 82 s, err = Collect(FilterMap(Empty[int](), func(_ int) (string, bool) { panic("unreachable") })) 83 require.NoError(t, err) 84 require.Empty(t, s) 85 86 // failure 87 err = Drain(FilterMap(Fail[int](fmt.Errorf("unexpected error")), func(_ int) (string, bool) { panic("unreachable") })) 88 require.Error(t, err) 89 } 90 91 // TestMapWhile tests the MapWhile combinator. 92 func TestMapWhile(t *testing.T) { 93 t.Parallel() 94 95 // normal usage 96 s, err := Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) { 97 if i == 3 { 98 return "", false 99 } 100 return fmt.Sprintf("%d", i*10), true 101 })) 102 require.NoError(t, err) 103 require.Equal(t, []string{"10", "20"}, s) 104 105 // halt after 1 element 106 s, err = Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) { 107 if i == 1 { 108 return "one", true 109 } 110 return "", false 111 })) 112 require.NoError(t, err) 113 require.Equal(t, []string{"one"}, s) 114 115 // halt immediately 116 s, err = Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(_ int) (string, bool) { 117 return "", false 118 })) 119 require.NoError(t, err) 120 require.Empty(t, s) 121 122 // empty stream 123 s, err = Collect(MapWhile(Empty[int](), func(_ int) (string, bool) { panic("unreachable") })) 124 require.NoError(t, err) 125 require.Empty(t, s) 126 127 // failure 128 err = Drain(MapWhile(Fail[int](fmt.Errorf("unexpected error")), func(_ int) (string, bool) { panic("unreachable") })) 129 require.Error(t, err) 130 } 131 132 // TestChain tests the Chain combinator. 133 func TestChain(t *testing.T) { 134 t.Parallel() 135 136 // normal usage 137 s, err := Collect(Chain( 138 Slice([]int{1, 2, 3}), 139 Slice([]int{4}), 140 Slice([]int{5, 6}), 141 )) 142 require.NoError(t, err) 143 require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s) 144 145 // single substream 146 s, err = Collect(Chain(Slice([]int{1, 2, 3}))) 147 require.NoError(t, err) 148 require.Equal(t, []int{1, 2, 3}, s) 149 150 // no substreams 151 s, err = Collect(Chain[int]()) 152 require.NoError(t, err) 153 require.Empty(t, s) 154 155 // some empty substreams 156 s, err = Collect(Chain( 157 Empty[int](), 158 Slice([]int{4, 5, 6}), 159 Empty[int](), 160 )) 161 require.NoError(t, err) 162 require.Equal(t, []int{4, 5, 6}, s) 163 164 // all empty substreams 165 s, err = Collect(Chain( 166 Empty[int](), 167 Empty[int](), 168 )) 169 require.NoError(t, err) 170 require.Empty(t, s) 171 172 // late failure 173 s, err = Collect(Chain( 174 Slice([]int{7, 7, 7}), 175 Fail[int](fmt.Errorf("some error")), 176 )) 177 require.Error(t, err) 178 require.Equal(t, []int{7, 7, 7}, s) 179 180 // early failure 181 s, err = Collect(Chain( 182 Fail[int](fmt.Errorf("some other error")), 183 Func(func() (int, error) { panic("unreachable") }), 184 )) 185 require.Error(t, err) 186 require.Empty(t, s) 187 } 188 189 // TestFunc tests the Func stream. 190 func TestFunc(t *testing.T) { 191 t.Parallel() 192 193 // normal usage 194 var n int 195 s, err := Collect(Func(func() (int, error) { 196 n++ 197 if n > 3 { 198 return 0, io.EOF 199 } 200 return n, nil 201 })) 202 require.NoError(t, err) 203 require.Equal(t, []int{1, 2, 3}, s) 204 205 // single-element 206 var once bool 207 s, err = Collect(Func(func() (int, error) { 208 if once { 209 return 0, io.EOF 210 } 211 once = true 212 return 100, nil 213 })) 214 require.NoError(t, err) 215 require.Equal(t, []int{100}, s) 216 217 // no element 218 s, err = Collect(Func(func() (int, error) { 219 return 0, io.EOF 220 })) 221 require.NoError(t, err) 222 require.Empty(t, s) 223 224 // immediate error 225 err = Drain(Func(func() (int, error) { 226 return 0, fmt.Errorf("unexpected error") 227 })) 228 require.Error(t, err) 229 230 // error after a few streamations 231 n = 0 232 err = Drain(Func(func() (int, error) { 233 n++ 234 if n > 10 { 235 return 0, fmt.Errorf("unexpected error") 236 } 237 return n, nil 238 })) 239 require.Error(t, err) 240 } 241 242 func TestPageFunc(t *testing.T) { 243 t.Parallel() 244 245 // basic pages 246 var n int 247 s, err := Collect(PageFunc(func() ([]int, error) { 248 n++ 249 if n > 3 { 250 return nil, io.EOF 251 } 252 return []int{ 253 n, 254 n * 10, 255 n * 100, 256 }, nil 257 })) 258 require.NoError(t, err) 259 require.Equal(t, []int{1, 10, 100, 2, 20, 200, 3, 30, 300}, s) 260 261 // single page 262 var once bool 263 s, err = Collect(PageFunc(func() ([]int, error) { 264 if once { 265 return nil, io.EOF 266 } 267 once = true 268 return []int{1, 2, 3}, nil 269 })) 270 require.NoError(t, err) 271 require.Equal(t, []int{1, 2, 3}, s) 272 273 // single element 274 once = false 275 s, err = Collect(PageFunc(func() ([]int, error) { 276 if once { 277 return nil, io.EOF 278 } 279 once = true 280 return []int{100}, nil 281 })) 282 require.NoError(t, err) 283 require.Equal(t, []int{100}, s) 284 285 // no pages 286 s, err = Collect(PageFunc(func() ([]int, error) { 287 return nil, io.EOF 288 })) 289 require.NoError(t, err) 290 require.Empty(t, s) 291 292 // lots of empty pages 293 n = 0 294 s, err = Collect(PageFunc(func() ([]int, error) { 295 n++ 296 switch n { 297 case 5: 298 return []int{1, 2, 3}, nil 299 case 10: 300 return []int{4, 5, 6}, nil 301 case 15: 302 return nil, io.EOF 303 default: 304 return nil, nil 305 } 306 })) 307 require.NoError(t, err) 308 require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s) 309 310 // only empty and/or nil pages 311 n = 0 312 s, err = Collect(PageFunc(func() ([]int, error) { 313 n++ 314 if n > 20 { 315 return nil, io.EOF 316 } 317 if n%2 == 0 { 318 return []int{}, nil 319 } 320 return nil, nil 321 })) 322 require.NoError(t, err) 323 require.Empty(t, s) 324 325 // eventual failure 326 n = 0 327 s, err = Collect(PageFunc(func() ([]int, error) { 328 n++ 329 if n > 3 { 330 return nil, fmt.Errorf("bad things") 331 } 332 return []int{1, 2, 3}, nil 333 })) 334 require.Error(t, err) 335 require.Equal(t, []int{1, 2, 3, 1, 2, 3, 1, 2, 3}, s) 336 337 // immediate failure 338 err = Drain(PageFunc(func() ([]int, error) { 339 return nil, fmt.Errorf("very bad things") 340 })) 341 require.Error(t, err) 342 } 343 344 // TestEmpty tests the Empty/Fail stream. 345 func TestEmpty(t *testing.T) { 346 t.Parallel() 347 348 // empty case 349 s, err := Collect(Empty[int]()) 350 require.NoError(t, err) 351 require.Empty(t, s) 352 353 // normal error case 354 s, err = Collect(Fail[int](fmt.Errorf("unexpected error"))) 355 require.Error(t, err) 356 require.Empty(t, s) 357 358 // nil error case 359 s, err = Collect(Fail[int](nil)) 360 require.NoError(t, err) 361 require.Empty(t, s) 362 } 363 364 // TestOnceFunc tests the OnceFunc stream combinator. 365 func TestOnceFunc(t *testing.T) { 366 t.Parallel() 367 368 // single-element variant 369 s, err := Collect(OnceFunc(func() (int, error) { 370 return 1, nil 371 })) 372 require.NoError(t, err) 373 require.Equal(t, []int{1}, s) 374 375 // empty stream case 376 s, err = Collect(OnceFunc(func() (int, error) { 377 return 1, io.EOF 378 })) 379 require.NoError(t, err) 380 require.Empty(t, s) 381 382 // error case 383 s, err = Collect(OnceFunc(func() (int, error) { 384 return 1, fmt.Errorf("unexpected error") 385 })) 386 require.Error(t, err) 387 require.Empty(t, s) 388 } 389 390 func TestCollectPages(t *testing.T) { 391 t.Parallel() 392 393 tts := []struct { 394 pages [][]string 395 expect []string 396 err error 397 desc string 398 }{ 399 { 400 pages: [][]string{ 401 {"foo", "bar"}, 402 {}, 403 {"bin", "baz"}, 404 }, 405 expect: []string{ 406 "foo", 407 "bar", 408 "bin", 409 "baz", 410 }, 411 desc: "basic-depagination", 412 }, 413 { 414 pages: [][]string{ 415 {"one"}, 416 }, 417 expect: []string{"one"}, 418 desc: "single-element-case", 419 }, 420 { 421 desc: "empty-case", 422 }, 423 { 424 err: fmt.Errorf("failure"), 425 desc: "error-case", 426 }, 427 } 428 429 for _, tt := range tts { 430 t.Run(tt.desc, func(t *testing.T) { 431 var stream Stream[[]string] 432 if tt.err == nil { 433 stream = Slice(tt.pages) 434 } else { 435 stream = Fail[[]string](tt.err) 436 } 437 collected, err := CollectPages(stream) 438 if tt.err == nil { 439 require.NoError(t, err) 440 } else { 441 require.Error(t, err) 442 } 443 if len(tt.expect) == 0 { 444 require.Empty(t, collected) 445 } else { 446 require.Equal(t, tt.expect, collected) 447 } 448 }) 449 } 450 } 451 452 func TestTake(t *testing.T) { 453 t.Parallel() 454 455 intSlice := func(n int) []int { 456 s := make([]int, 0, n) 457 for i := 0; i < n; i++ { 458 s = append(s, i) 459 } 460 return s 461 } 462 463 tests := []struct { 464 name string 465 input []int 466 n int 467 expectedOutput []int 468 expectMore bool 469 }{ 470 { 471 name: "empty stream", 472 input: []int{}, 473 n: 10, 474 expectedOutput: []int{}, 475 expectMore: false, 476 }, 477 { 478 name: "full stream", 479 input: intSlice(20), 480 n: 10, 481 expectedOutput: intSlice(10), 482 expectMore: true, 483 }, 484 { 485 name: "drain stream of size n", 486 input: intSlice(10), 487 n: 10, 488 expectedOutput: intSlice(10), 489 expectMore: true, 490 }, 491 { 492 name: "drain stream of size < n", 493 input: intSlice(5), 494 n: 10, 495 expectedOutput: intSlice(5), 496 expectMore: false, 497 }, 498 } 499 for _, tc := range tests { 500 t.Run(tc.name, func(t *testing.T) { 501 stream := Slice(tc.input) 502 output, more := Take(stream, tc.n) 503 require.Equal(t, tc.expectedOutput, output) 504 require.Equal(t, tc.expectMore, more) 505 }) 506 } 507 } 508 509 // TestRateLimitFailure verifies the expected failure conditions of the RateLimit helper. 510 func TestRateLimitFailure(t *testing.T) { 511 t.Parallel() 512 513 var limiterError = errors.New("limiter-error") 514 var streamError = errors.New("stream-error") 515 516 tts := []struct { 517 desc string 518 items int 519 stream error 520 limiter error 521 expect error 522 }{ 523 { 524 desc: "simultaneous", 525 stream: streamError, 526 limiter: limiterError, 527 expect: streamError, 528 }, 529 { 530 desc: "stream-only", 531 stream: streamError, 532 expect: streamError, 533 }, 534 { 535 desc: "limiter-only", 536 limiter: limiterError, 537 expect: limiterError, 538 }, 539 { 540 desc: "limiter-graceful", 541 limiter: io.EOF, 542 expect: nil, 543 }, 544 } 545 546 for _, tt := range tts { 547 t.Run(tt.desc, func(t *testing.T) { 548 err := Drain(RateLimit(Fail[int](tt.stream), func() error { return tt.limiter })) 549 if tt.expect == nil { 550 require.NoError(t, err) 551 return 552 } 553 554 require.ErrorIs(t, err, tt.expect) 555 }) 556 } 557 } 558 559 // TestRateLimit sets up a concurrent channel-based limiter and verifies its effect on a pool of workers consuming 560 // items from streams. 561 func TestRateLimit(t *testing.T) { 562 t.Parallel() 563 564 const workers = 16 565 const maxItemsPerWorker = 16 566 const tokens = 100 567 const burst = 10 568 569 lim := make(chan struct{}, burst) 570 done := make(chan struct{}) 571 572 results := make(chan error, workers) 573 574 items := make(chan struct{}, tokens+1) 575 576 for i := 0; i < workers; i++ { 577 go func() { 578 stream := RateLimit(repeat("some-item", maxItemsPerWorker), func() error { 579 select { 580 case <-lim: 581 return nil 582 case <-done: 583 // make sure we still consume remaining tokens even if 'done' is closed (simplifies 584 // test logic by letting us close 'done' immediately after sending last token without 585 // worrying about racing). 586 select { 587 case <-lim: 588 return nil 589 default: 590 return io.EOF 591 } 592 } 593 }) 594 595 for stream.Next() { 596 items <- struct{}{} 597 } 598 599 results <- stream.Done() 600 }() 601 } 602 603 // yielded tracks total number of tokens yielded on limiter channel 604 var yielded int 605 606 // do an initial fill of limiter channel 607 for i := 0; i < burst; i++ { 608 select { 609 case lim <- struct{}{}: 610 yielded++ 611 default: 612 require.FailNow(t, "initial burst should never block") 613 } 614 } 615 616 var consumed int 617 618 // consume item receipt events 619 timeoutC := time.After(time.Second * 30) 620 for i := 0; i < burst; i++ { 621 select { 622 case <-items: 623 consumed++ 624 case <-timeoutC: 625 require.FailNow(t, "timeout waiting for item") 626 } 627 } 628 629 // ensure no more items available 630 select { 631 case <-items: 632 require.FailNow(t, "received item without corresponding token yield") 633 default: 634 } 635 636 // yield the rest of the tokens 637 for yielded < tokens { 638 select { 639 case lim <- struct{}{}: 640 yielded++ 641 case <-timeoutC: 642 require.FailNow(t, "timeout waiting to yield token") 643 } 644 } 645 646 // signal workers that they should exit once remaining tokens 647 // are consumed. 648 close(done) 649 650 // wait for all workers to finish 651 for i := 0; i < workers; i++ { 652 select { 653 case err := <-results: 654 require.NoError(t, err) 655 case <-timeoutC: 656 require.FailNow(t, "timeout waiting for worker to exit") 657 } 658 } 659 660 // consume the rest of the item events 661 ConsumeItems: 662 for { 663 select { 664 case <-items: 665 consumed++ 666 default: 667 break ConsumeItems 668 } 669 } 670 671 // note that total number of processed items may actually vary since we are rate-limiting 672 // how frequently a stream is *polled*, not how frequently it yields an item. A stream being 673 // polled may result in us discovering that it is empty, in which case a limiter token is still 674 // consumed, but no item is yielded. 675 require.LessOrEqual(t, consumed, tokens) 676 require.GreaterOrEqual(t, consumed, tokens-workers) 677 } 678 679 // repeat repeats the same item N times 680 func repeat[T any](item T, count int) Stream[T] { 681 var n int 682 return Func(func() (T, error) { 683 n++ 684 if n > count { 685 var zero T 686 return zero, io.EOF 687 } 688 return item, nil 689 }) 690 } 691 692 // TestMergeStreams tests the MergeStreams adapter. 693 func TestMergeStreams(t *testing.T) { 694 t.Parallel() 695 696 // Mock convert function that converts the strings in streamB to integers. 697 convertBFunc := func(val string) int { 698 bValue, _ := strconv.Atoi(val) 699 return bValue 700 } 701 702 // Since streamA is already the type we want from the merged stream, the convertA function just returns the item as-is. 703 convertAFunc := func(item int) int { return item } 704 705 // Mock compare function that favors the lower value. 706 compareFunc := func(a int, b string) bool { 707 return a <= convertBFunc(b) 708 } 709 710 // Test the case where the streams should have interlaced values. 711 t.Run("interlaced streams", func(t *testing.T) { 712 streamA := Slice([]int{1, 3, 5}) 713 streamB := Slice([]string{"2", "4", "6"}) 714 715 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 716 out, err := Collect(resultStream) 717 718 require.NoError(t, err) 719 require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out) 720 721 err = resultStream.Done() 722 require.NoError(t, err) 723 }) 724 725 // Test the case where streamA is empty. 726 t.Run("stream A empty", func(t *testing.T) { 727 streamA := Empty[int]() 728 streamB := Slice([]string{"1", "2", "3"}) 729 730 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 731 out, err := Collect(resultStream) 732 733 require.NoError(t, err) 734 require.Equal(t, []int{1, 2, 3}, out) 735 736 err = resultStream.Done() 737 require.NoError(t, err) 738 }) 739 740 // Test the case where streamB is empty. 741 t.Run("stream B empty", func(t *testing.T) { 742 streamA := Slice([]int{1, 2, 3}) 743 streamB := Empty[string]() 744 745 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 746 out, err := Collect(resultStream) 747 748 require.NoError(t, err) 749 require.Equal(t, []int{1, 2, 3}, out) 750 751 err = resultStream.Done() 752 require.NoError(t, err) 753 }) 754 755 // Test the case where both streams are empty. 756 t.Run("both streams empty", func(t *testing.T) { 757 streamA := Empty[int]() 758 streamB := Empty[string]() 759 760 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 761 out, err := Collect(resultStream) 762 763 require.NoError(t, err) 764 require.Empty(t, out) 765 766 err = resultStream.Done() 767 require.NoError(t, err) 768 }) 769 770 // Test the case where every value in streamA is lower than every value in streamB. 771 t.Run("compare always favors A", func(t *testing.T) { 772 streamA := Slice([]int{1, 2, 3}) 773 streamB := Slice([]string{"4", "5", "6"}) 774 775 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 776 out, err := Collect(resultStream) 777 778 require.NoError(t, err) 779 require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out) 780 781 err = resultStream.Done() 782 require.NoError(t, err) 783 }) 784 785 // Test the case where every value in streamB is lower than every value in streamA. 786 t.Run("compare always favors B", func(t *testing.T) { 787 streamA := Slice([]int{4, 5, 6}) 788 streamB := Slice([]string{"1", "2", "3"}) 789 790 resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc) 791 out, err := Collect(resultStream) 792 793 require.NoError(t, err) 794 require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out) 795 796 err = resultStream.Done() 797 require.NoError(t, err) 798 }) 799 }