github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/internalutils/stream/stream_test.go (about)

     1  /*
     2  Copyright 2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package stream
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"strconv"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  // TestSlice tests the slice stream.
    31  func TestSlice(t *testing.T) {
    32  	t.Parallel()
    33  
    34  	// normal usage
    35  	s, err := Collect(Slice([]int{1, 2, 3}))
    36  	require.NoError(t, err)
    37  	require.Equal(t, []int{1, 2, 3}, s)
    38  
    39  	// single-element slice
    40  	s, err = Collect(Slice([]int{100}))
    41  	require.NoError(t, err)
    42  	require.Equal(t, []int{100}, s)
    43  
    44  	// nil slice
    45  	s, err = Collect(Slice[int](nil))
    46  	require.NoError(t, err)
    47  	require.Empty(t, s)
    48  }
    49  
    50  // TestFilterMap tests the FilterMap combinator.
    51  func TestFilterMap(t *testing.T) {
    52  	t.Parallel()
    53  
    54  	// normal usage
    55  	s, err := Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) {
    56  		if i%2 == 0 {
    57  			return fmt.Sprintf("%d", i*10), true
    58  		}
    59  		return "", false
    60  	}))
    61  	require.NoError(t, err)
    62  	require.Equal(t, []string{"20", "40"}, s)
    63  
    64  	// single-match
    65  	s, err = Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) {
    66  		if i == 3 {
    67  			return "three", true
    68  		}
    69  		return "", false
    70  	}))
    71  	require.NoError(t, err)
    72  	require.Equal(t, []string{"three"}, s)
    73  
    74  	// no matches
    75  	s, err = Collect(FilterMap(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) {
    76  		return "", false
    77  	}))
    78  	require.NoError(t, err)
    79  	require.Empty(t, s)
    80  
    81  	// empty stream
    82  	s, err = Collect(FilterMap(Empty[int](), func(_ int) (string, bool) { panic("unreachable") }))
    83  	require.NoError(t, err)
    84  	require.Empty(t, s)
    85  
    86  	// failure
    87  	err = Drain(FilterMap(Fail[int](fmt.Errorf("unexpected error")), func(_ int) (string, bool) { panic("unreachable") }))
    88  	require.Error(t, err)
    89  }
    90  
    91  // TestMapWhile tests the MapWhile combinator.
    92  func TestMapWhile(t *testing.T) {
    93  	t.Parallel()
    94  
    95  	// normal usage
    96  	s, err := Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) {
    97  		if i == 3 {
    98  			return "", false
    99  		}
   100  		return fmt.Sprintf("%d", i*10), true
   101  	}))
   102  	require.NoError(t, err)
   103  	require.Equal(t, []string{"10", "20"}, s)
   104  
   105  	// halt after 1 element
   106  	s, err = Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(i int) (string, bool) {
   107  		if i == 1 {
   108  			return "one", true
   109  		}
   110  		return "", false
   111  	}))
   112  	require.NoError(t, err)
   113  	require.Equal(t, []string{"one"}, s)
   114  
   115  	// halt immediately
   116  	s, err = Collect(MapWhile(Slice([]int{1, 2, 3, 4}), func(_ int) (string, bool) {
   117  		return "", false
   118  	}))
   119  	require.NoError(t, err)
   120  	require.Empty(t, s)
   121  
   122  	// empty stream
   123  	s, err = Collect(MapWhile(Empty[int](), func(_ int) (string, bool) { panic("unreachable") }))
   124  	require.NoError(t, err)
   125  	require.Empty(t, s)
   126  
   127  	// failure
   128  	err = Drain(MapWhile(Fail[int](fmt.Errorf("unexpected error")), func(_ int) (string, bool) { panic("unreachable") }))
   129  	require.Error(t, err)
   130  }
   131  
   132  // TestChain tests the Chain combinator.
   133  func TestChain(t *testing.T) {
   134  	t.Parallel()
   135  
   136  	// normal usage
   137  	s, err := Collect(Chain(
   138  		Slice([]int{1, 2, 3}),
   139  		Slice([]int{4}),
   140  		Slice([]int{5, 6}),
   141  	))
   142  	require.NoError(t, err)
   143  	require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s)
   144  
   145  	// single substream
   146  	s, err = Collect(Chain(Slice([]int{1, 2, 3})))
   147  	require.NoError(t, err)
   148  	require.Equal(t, []int{1, 2, 3}, s)
   149  
   150  	// no substreams
   151  	s, err = Collect(Chain[int]())
   152  	require.NoError(t, err)
   153  	require.Empty(t, s)
   154  
   155  	// some empty substreams
   156  	s, err = Collect(Chain(
   157  		Empty[int](),
   158  		Slice([]int{4, 5, 6}),
   159  		Empty[int](),
   160  	))
   161  	require.NoError(t, err)
   162  	require.Equal(t, []int{4, 5, 6}, s)
   163  
   164  	// all empty substreams
   165  	s, err = Collect(Chain(
   166  		Empty[int](),
   167  		Empty[int](),
   168  	))
   169  	require.NoError(t, err)
   170  	require.Empty(t, s)
   171  
   172  	// late failure
   173  	s, err = Collect(Chain(
   174  		Slice([]int{7, 7, 7}),
   175  		Fail[int](fmt.Errorf("some error")),
   176  	))
   177  	require.Error(t, err)
   178  	require.Equal(t, []int{7, 7, 7}, s)
   179  
   180  	// early failure
   181  	s, err = Collect(Chain(
   182  		Fail[int](fmt.Errorf("some other error")),
   183  		Func(func() (int, error) { panic("unreachable") }),
   184  	))
   185  	require.Error(t, err)
   186  	require.Empty(t, s)
   187  }
   188  
   189  // TestFunc tests the Func stream.
   190  func TestFunc(t *testing.T) {
   191  	t.Parallel()
   192  
   193  	// normal usage
   194  	var n int
   195  	s, err := Collect(Func(func() (int, error) {
   196  		n++
   197  		if n > 3 {
   198  			return 0, io.EOF
   199  		}
   200  		return n, nil
   201  	}))
   202  	require.NoError(t, err)
   203  	require.Equal(t, []int{1, 2, 3}, s)
   204  
   205  	// single-element
   206  	var once bool
   207  	s, err = Collect(Func(func() (int, error) {
   208  		if once {
   209  			return 0, io.EOF
   210  		}
   211  		once = true
   212  		return 100, nil
   213  	}))
   214  	require.NoError(t, err)
   215  	require.Equal(t, []int{100}, s)
   216  
   217  	// no element
   218  	s, err = Collect(Func(func() (int, error) {
   219  		return 0, io.EOF
   220  	}))
   221  	require.NoError(t, err)
   222  	require.Empty(t, s)
   223  
   224  	// immediate error
   225  	err = Drain(Func(func() (int, error) {
   226  		return 0, fmt.Errorf("unexpected error")
   227  	}))
   228  	require.Error(t, err)
   229  
   230  	// error after a few streamations
   231  	n = 0
   232  	err = Drain(Func(func() (int, error) {
   233  		n++
   234  		if n > 10 {
   235  			return 0, fmt.Errorf("unexpected error")
   236  		}
   237  		return n, nil
   238  	}))
   239  	require.Error(t, err)
   240  }
   241  
   242  func TestPageFunc(t *testing.T) {
   243  	t.Parallel()
   244  
   245  	// basic pages
   246  	var n int
   247  	s, err := Collect(PageFunc(func() ([]int, error) {
   248  		n++
   249  		if n > 3 {
   250  			return nil, io.EOF
   251  		}
   252  		return []int{
   253  			n,
   254  			n * 10,
   255  			n * 100,
   256  		}, nil
   257  	}))
   258  	require.NoError(t, err)
   259  	require.Equal(t, []int{1, 10, 100, 2, 20, 200, 3, 30, 300}, s)
   260  
   261  	// single page
   262  	var once bool
   263  	s, err = Collect(PageFunc(func() ([]int, error) {
   264  		if once {
   265  			return nil, io.EOF
   266  		}
   267  		once = true
   268  		return []int{1, 2, 3}, nil
   269  	}))
   270  	require.NoError(t, err)
   271  	require.Equal(t, []int{1, 2, 3}, s)
   272  
   273  	// single element
   274  	once = false
   275  	s, err = Collect(PageFunc(func() ([]int, error) {
   276  		if once {
   277  			return nil, io.EOF
   278  		}
   279  		once = true
   280  		return []int{100}, nil
   281  	}))
   282  	require.NoError(t, err)
   283  	require.Equal(t, []int{100}, s)
   284  
   285  	// no pages
   286  	s, err = Collect(PageFunc(func() ([]int, error) {
   287  		return nil, io.EOF
   288  	}))
   289  	require.NoError(t, err)
   290  	require.Empty(t, s)
   291  
   292  	// lots of empty pages
   293  	n = 0
   294  	s, err = Collect(PageFunc(func() ([]int, error) {
   295  		n++
   296  		switch n {
   297  		case 5:
   298  			return []int{1, 2, 3}, nil
   299  		case 10:
   300  			return []int{4, 5, 6}, nil
   301  		case 15:
   302  			return nil, io.EOF
   303  		default:
   304  			return nil, nil
   305  		}
   306  	}))
   307  	require.NoError(t, err)
   308  	require.Equal(t, []int{1, 2, 3, 4, 5, 6}, s)
   309  
   310  	// only empty and/or nil pages
   311  	n = 0
   312  	s, err = Collect(PageFunc(func() ([]int, error) {
   313  		n++
   314  		if n > 20 {
   315  			return nil, io.EOF
   316  		}
   317  		if n%2 == 0 {
   318  			return []int{}, nil
   319  		}
   320  		return nil, nil
   321  	}))
   322  	require.NoError(t, err)
   323  	require.Empty(t, s)
   324  
   325  	// eventual failure
   326  	n = 0
   327  	s, err = Collect(PageFunc(func() ([]int, error) {
   328  		n++
   329  		if n > 3 {
   330  			return nil, fmt.Errorf("bad things")
   331  		}
   332  		return []int{1, 2, 3}, nil
   333  	}))
   334  	require.Error(t, err)
   335  	require.Equal(t, []int{1, 2, 3, 1, 2, 3, 1, 2, 3}, s)
   336  
   337  	// immediate failure
   338  	err = Drain(PageFunc(func() ([]int, error) {
   339  		return nil, fmt.Errorf("very bad things")
   340  	}))
   341  	require.Error(t, err)
   342  }
   343  
   344  // TestEmpty tests the Empty/Fail stream.
   345  func TestEmpty(t *testing.T) {
   346  	t.Parallel()
   347  
   348  	// empty case
   349  	s, err := Collect(Empty[int]())
   350  	require.NoError(t, err)
   351  	require.Empty(t, s)
   352  
   353  	// normal error case
   354  	s, err = Collect(Fail[int](fmt.Errorf("unexpected error")))
   355  	require.Error(t, err)
   356  	require.Empty(t, s)
   357  
   358  	// nil error case
   359  	s, err = Collect(Fail[int](nil))
   360  	require.NoError(t, err)
   361  	require.Empty(t, s)
   362  }
   363  
   364  // TestOnceFunc tests the OnceFunc stream combinator.
   365  func TestOnceFunc(t *testing.T) {
   366  	t.Parallel()
   367  
   368  	// single-element variant
   369  	s, err := Collect(OnceFunc(func() (int, error) {
   370  		return 1, nil
   371  	}))
   372  	require.NoError(t, err)
   373  	require.Equal(t, []int{1}, s)
   374  
   375  	// empty stream case
   376  	s, err = Collect(OnceFunc(func() (int, error) {
   377  		return 1, io.EOF
   378  	}))
   379  	require.NoError(t, err)
   380  	require.Empty(t, s)
   381  
   382  	// error case
   383  	s, err = Collect(OnceFunc(func() (int, error) {
   384  		return 1, fmt.Errorf("unexpected error")
   385  	}))
   386  	require.Error(t, err)
   387  	require.Empty(t, s)
   388  }
   389  
   390  func TestCollectPages(t *testing.T) {
   391  	t.Parallel()
   392  
   393  	tts := []struct {
   394  		pages  [][]string
   395  		expect []string
   396  		err    error
   397  		desc   string
   398  	}{
   399  		{
   400  			pages: [][]string{
   401  				{"foo", "bar"},
   402  				{},
   403  				{"bin", "baz"},
   404  			},
   405  			expect: []string{
   406  				"foo",
   407  				"bar",
   408  				"bin",
   409  				"baz",
   410  			},
   411  			desc: "basic-depagination",
   412  		},
   413  		{
   414  			pages: [][]string{
   415  				{"one"},
   416  			},
   417  			expect: []string{"one"},
   418  			desc:   "single-element-case",
   419  		},
   420  		{
   421  			desc: "empty-case",
   422  		},
   423  		{
   424  			err:  fmt.Errorf("failure"),
   425  			desc: "error-case",
   426  		},
   427  	}
   428  
   429  	for _, tt := range tts {
   430  		t.Run(tt.desc, func(t *testing.T) {
   431  			var stream Stream[[]string]
   432  			if tt.err == nil {
   433  				stream = Slice(tt.pages)
   434  			} else {
   435  				stream = Fail[[]string](tt.err)
   436  			}
   437  			collected, err := CollectPages(stream)
   438  			if tt.err == nil {
   439  				require.NoError(t, err)
   440  			} else {
   441  				require.Error(t, err)
   442  			}
   443  			if len(tt.expect) == 0 {
   444  				require.Empty(t, collected)
   445  			} else {
   446  				require.Equal(t, tt.expect, collected)
   447  			}
   448  		})
   449  	}
   450  }
   451  
   452  func TestTake(t *testing.T) {
   453  	t.Parallel()
   454  
   455  	intSlice := func(n int) []int {
   456  		s := make([]int, 0, n)
   457  		for i := 0; i < n; i++ {
   458  			s = append(s, i)
   459  		}
   460  		return s
   461  	}
   462  
   463  	tests := []struct {
   464  		name           string
   465  		input          []int
   466  		n              int
   467  		expectedOutput []int
   468  		expectMore     bool
   469  	}{
   470  		{
   471  			name:           "empty stream",
   472  			input:          []int{},
   473  			n:              10,
   474  			expectedOutput: []int{},
   475  			expectMore:     false,
   476  		},
   477  		{
   478  			name:           "full stream",
   479  			input:          intSlice(20),
   480  			n:              10,
   481  			expectedOutput: intSlice(10),
   482  			expectMore:     true,
   483  		},
   484  		{
   485  			name:           "drain stream of size n",
   486  			input:          intSlice(10),
   487  			n:              10,
   488  			expectedOutput: intSlice(10),
   489  			expectMore:     true,
   490  		},
   491  		{
   492  			name:           "drain stream of size < n",
   493  			input:          intSlice(5),
   494  			n:              10,
   495  			expectedOutput: intSlice(5),
   496  			expectMore:     false,
   497  		},
   498  	}
   499  	for _, tc := range tests {
   500  		t.Run(tc.name, func(t *testing.T) {
   501  			stream := Slice(tc.input)
   502  			output, more := Take(stream, tc.n)
   503  			require.Equal(t, tc.expectedOutput, output)
   504  			require.Equal(t, tc.expectMore, more)
   505  		})
   506  	}
   507  }
   508  
   509  // TestRateLimitFailure verifies the expected failure conditions of the RateLimit helper.
   510  func TestRateLimitFailure(t *testing.T) {
   511  	t.Parallel()
   512  
   513  	var limiterError = errors.New("limiter-error")
   514  	var streamError = errors.New("stream-error")
   515  
   516  	tts := []struct {
   517  		desc    string
   518  		items   int
   519  		stream  error
   520  		limiter error
   521  		expect  error
   522  	}{
   523  		{
   524  			desc:    "simultaneous",
   525  			stream:  streamError,
   526  			limiter: limiterError,
   527  			expect:  streamError,
   528  		},
   529  		{
   530  			desc:   "stream-only",
   531  			stream: streamError,
   532  			expect: streamError,
   533  		},
   534  		{
   535  			desc:    "limiter-only",
   536  			limiter: limiterError,
   537  			expect:  limiterError,
   538  		},
   539  		{
   540  			desc:    "limiter-graceful",
   541  			limiter: io.EOF,
   542  			expect:  nil,
   543  		},
   544  	}
   545  
   546  	for _, tt := range tts {
   547  		t.Run(tt.desc, func(t *testing.T) {
   548  			err := Drain(RateLimit(Fail[int](tt.stream), func() error { return tt.limiter }))
   549  			if tt.expect == nil {
   550  				require.NoError(t, err)
   551  				return
   552  			}
   553  
   554  			require.ErrorIs(t, err, tt.expect)
   555  		})
   556  	}
   557  }
   558  
   559  // TestRateLimit sets up a concurrent channel-based limiter and verifies its effect on a pool of workers consuming
   560  // items from streams.
   561  func TestRateLimit(t *testing.T) {
   562  	t.Parallel()
   563  
   564  	const workers = 16
   565  	const maxItemsPerWorker = 16
   566  	const tokens = 100
   567  	const burst = 10
   568  
   569  	lim := make(chan struct{}, burst)
   570  	done := make(chan struct{})
   571  
   572  	results := make(chan error, workers)
   573  
   574  	items := make(chan struct{}, tokens+1)
   575  
   576  	for i := 0; i < workers; i++ {
   577  		go func() {
   578  			stream := RateLimit(repeat("some-item", maxItemsPerWorker), func() error {
   579  				select {
   580  				case <-lim:
   581  					return nil
   582  				case <-done:
   583  					// make sure we still consume remaining tokens even if 'done' is closed (simplifies
   584  					// test logic by letting us close 'done' immediately after sending last token without
   585  					// worrying about racing).
   586  					select {
   587  					case <-lim:
   588  						return nil
   589  					default:
   590  						return io.EOF
   591  					}
   592  				}
   593  			})
   594  
   595  			for stream.Next() {
   596  				items <- struct{}{}
   597  			}
   598  
   599  			results <- stream.Done()
   600  		}()
   601  	}
   602  
   603  	// yielded tracks total number of tokens yielded on limiter channel
   604  	var yielded int
   605  
   606  	// do an initial fill of limiter channel
   607  	for i := 0; i < burst; i++ {
   608  		select {
   609  		case lim <- struct{}{}:
   610  			yielded++
   611  		default:
   612  			require.FailNow(t, "initial burst should never block")
   613  		}
   614  	}
   615  
   616  	var consumed int
   617  
   618  	// consume item receipt events
   619  	timeoutC := time.After(time.Second * 30)
   620  	for i := 0; i < burst; i++ {
   621  		select {
   622  		case <-items:
   623  			consumed++
   624  		case <-timeoutC:
   625  			require.FailNow(t, "timeout waiting for item")
   626  		}
   627  	}
   628  
   629  	// ensure no more items available
   630  	select {
   631  	case <-items:
   632  		require.FailNow(t, "received item without corresponding token yield")
   633  	default:
   634  	}
   635  
   636  	// yield the rest of the tokens
   637  	for yielded < tokens {
   638  		select {
   639  		case lim <- struct{}{}:
   640  			yielded++
   641  		case <-timeoutC:
   642  			require.FailNow(t, "timeout waiting to yield token")
   643  		}
   644  	}
   645  
   646  	// signal workers that they should exit once remaining tokens
   647  	// are consumed.
   648  	close(done)
   649  
   650  	// wait for all workers to finish
   651  	for i := 0; i < workers; i++ {
   652  		select {
   653  		case err := <-results:
   654  			require.NoError(t, err)
   655  		case <-timeoutC:
   656  			require.FailNow(t, "timeout waiting for worker to exit")
   657  		}
   658  	}
   659  
   660  	// consume the rest of the item events
   661  ConsumeItems:
   662  	for {
   663  		select {
   664  		case <-items:
   665  			consumed++
   666  		default:
   667  			break ConsumeItems
   668  		}
   669  	}
   670  
   671  	// note that total number of processed items may actually vary since we are rate-limiting
   672  	// how frequently a stream is *polled*, not how frequently it yields an item. A stream being
   673  	// polled may result in us discovering that it is empty, in which case a limiter token is still
   674  	// consumed, but no item is yielded.
   675  	require.LessOrEqual(t, consumed, tokens)
   676  	require.GreaterOrEqual(t, consumed, tokens-workers)
   677  }
   678  
   679  // repeat repeats the same item N times
   680  func repeat[T any](item T, count int) Stream[T] {
   681  	var n int
   682  	return Func(func() (T, error) {
   683  		n++
   684  		if n > count {
   685  			var zero T
   686  			return zero, io.EOF
   687  		}
   688  		return item, nil
   689  	})
   690  }
   691  
   692  // TestMergeStreams tests the MergeStreams adapter.
   693  func TestMergeStreams(t *testing.T) {
   694  	t.Parallel()
   695  
   696  	// Mock convert function that converts the strings in streamB to integers.
   697  	convertBFunc := func(val string) int {
   698  		bValue, _ := strconv.Atoi(val)
   699  		return bValue
   700  	}
   701  
   702  	// Since streamA is already the type we want from the merged stream, the convertA function just returns the item as-is.
   703  	convertAFunc := func(item int) int { return item }
   704  
   705  	// Mock compare function that favors the lower value.
   706  	compareFunc := func(a int, b string) bool {
   707  		return a <= convertBFunc(b)
   708  	}
   709  
   710  	// Test the case where the streams should have interlaced values.
   711  	t.Run("interlaced streams", func(t *testing.T) {
   712  		streamA := Slice([]int{1, 3, 5})
   713  		streamB := Slice([]string{"2", "4", "6"})
   714  
   715  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   716  		out, err := Collect(resultStream)
   717  
   718  		require.NoError(t, err)
   719  		require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out)
   720  
   721  		err = resultStream.Done()
   722  		require.NoError(t, err)
   723  	})
   724  
   725  	// Test the case where streamA is empty.
   726  	t.Run("stream A empty", func(t *testing.T) {
   727  		streamA := Empty[int]()
   728  		streamB := Slice([]string{"1", "2", "3"})
   729  
   730  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   731  		out, err := Collect(resultStream)
   732  
   733  		require.NoError(t, err)
   734  		require.Equal(t, []int{1, 2, 3}, out)
   735  
   736  		err = resultStream.Done()
   737  		require.NoError(t, err)
   738  	})
   739  
   740  	// Test the case where streamB is empty.
   741  	t.Run("stream B empty", func(t *testing.T) {
   742  		streamA := Slice([]int{1, 2, 3})
   743  		streamB := Empty[string]()
   744  
   745  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   746  		out, err := Collect(resultStream)
   747  
   748  		require.NoError(t, err)
   749  		require.Equal(t, []int{1, 2, 3}, out)
   750  
   751  		err = resultStream.Done()
   752  		require.NoError(t, err)
   753  	})
   754  
   755  	// Test the case where both streams are empty.
   756  	t.Run("both streams empty", func(t *testing.T) {
   757  		streamA := Empty[int]()
   758  		streamB := Empty[string]()
   759  
   760  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   761  		out, err := Collect(resultStream)
   762  
   763  		require.NoError(t, err)
   764  		require.Empty(t, out)
   765  
   766  		err = resultStream.Done()
   767  		require.NoError(t, err)
   768  	})
   769  
   770  	// Test the case where every value in streamA is lower than every value in streamB.
   771  	t.Run("compare always favors A", func(t *testing.T) {
   772  		streamA := Slice([]int{1, 2, 3})
   773  		streamB := Slice([]string{"4", "5", "6"})
   774  
   775  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   776  		out, err := Collect(resultStream)
   777  
   778  		require.NoError(t, err)
   779  		require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out)
   780  
   781  		err = resultStream.Done()
   782  		require.NoError(t, err)
   783  	})
   784  
   785  	// Test the case where every value in streamB is lower than every value in streamA.
   786  	t.Run("compare always favors B", func(t *testing.T) {
   787  		streamA := Slice([]int{4, 5, 6})
   788  		streamB := Slice([]string{"1", "2", "3"})
   789  
   790  		resultStream := MergeStreams(streamA, streamB, compareFunc, convertAFunc, convertBFunc)
   791  		out, err := Collect(resultStream)
   792  
   793  		require.NoError(t, err)
   794  		require.Equal(t, []int{1, 2, 3, 4, 5, 6}, out)
   795  
   796  		err = resultStream.Done()
   797  		require.NoError(t, err)
   798  	})
   799  }