github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/routers_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package colexec
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"math/rand"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/cockroachdb/cockroach/pkg/col/coldata"
    23  	"github.com/cockroachdb/cockroach/pkg/col/coldatatestutils"
    24  	"github.com/cockroachdb/cockroach/pkg/sql/colcontainer"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/colmem"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    29  	"github.com/cockroachdb/cockroach/pkg/testutils"
    30  	"github.com/cockroachdb/cockroach/pkg/testutils/colcontainerutils"
    31  	"github.com/cockroachdb/cockroach/pkg/util/humanizeutil"
    32  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    33  	"github.com/cockroachdb/cockroach/pkg/util/mon"
    34  	"github.com/cockroachdb/cockroach/pkg/util/randutil"
    35  	"github.com/cockroachdb/errors"
    36  	"github.com/stretchr/testify/require"
    37  )
    38  
    39  // memoryTestCase describes one memory-limit scenario that a test is run under.
    40  type memoryTestCase struct {
    41  	// bytes is the memory limit, in bytes.
    42  	bytes int64
    43  	// skipExpSpillCheck, when true, tells the test not to compare expSpill
    44  	// against the observed spilling behavior. This is true if bytes was
    45  	// randomly generated.
    46  	skipExpSpillCheck bool
    47  	// expSpill specifies whether a spill is expected or not. Should be ignored if
    48  	// skipExpSpillCheck is true.
    49  	expSpill bool
    50  }
    51  
    52  // getDiskqueueCfgAndMemoryTestCases is a test helper that creates an in-memory
    53  // DiskQueueCfg that can be used to create a new DiskQueue. A cleanup function
    54  // is also returned as well as some default memory limits that are useful to
    55  // test with: 0 for an immediate spill, a random memory limit up to 64 MiB, and
    56  // 1GiB, which shouldn't result in a spill.
    57  // Note that not all tests will check for a spill, it is enough that some
    58  // deterministic tests do so for the simple cases.
    59  // TODO(asubiotto): We might want to also return a verify() function that will
    60  //  check for leftover files.
    61  func getDiskQueueCfgAndMemoryTestCases(
    62  	t *testing.T, rng *rand.Rand,
    63  ) (colcontainer.DiskQueueCfg, func(), []memoryTestCase) {
    64  	t.Helper()
    65  	queueCfg, cleanup := colcontainerutils.NewTestingDiskQueueCfg(t, true /* inMem */)
    66  
    67  	return queueCfg, cleanup, []memoryTestCase{
    68  		{bytes: 0, expSpill: true},
    69  		{bytes: 1 + rng.Int63n(64<<20 /* 64 MiB */), skipExpSpillCheck: true},
    70  		{bytes: 1 << 30 /* 1 GiB */, expSpill: false},
    71  	}
    72  }
    73  
    74  // getDataAndFullSelection is a test helper that generates tuples representing
    75  // a batch with single int64 column where each element is its ordinal and an
    76  // accompanying selection vector that selects every index in tuples.
    77  func getDataAndFullSelection() (tuples, []*types.T, []int) {
    78  	data := make(tuples, coldata.BatchSize())
    79  	fullSelection := make([]int, coldata.BatchSize())
    80  	for i := range data {
    81  		data[i] = tuple{i}
    82  		fullSelection[i] = i
    83  	}
    84  	return data, []*types.T{types.Int}, fullSelection
    85  }
    86  
    87  func TestRouterOutputAddBatch(t *testing.T) {
    88  	defer leaktest.AfterTest(t)()
    89  	ctx := context.Background()
    90  
    91  	data, typs, fullSelection := getDataAndFullSelection()
    92  
    93  	// Since the actual data doesn't matter, we will just be reusing data for each
    94  	// test case.
    95  	testCases := []struct {
    96  		inputBatchSize   int
    97  		outputBatchSize  int
    98  		blockedThreshold int
    99  		// selection determines which indices to add to the router output as well
   100  		// as how many elements from data are compared to the output.
   101  		selection []int
   102  		name      string
   103  	}{
   104  		{
   105  			inputBatchSize:   coldata.BatchSize(),
   106  			outputBatchSize:  coldata.BatchSize(),
   107  			blockedThreshold: getDefaultRouterOutputBlockedThreshold(),
   108  			selection:        fullSelection,
   109  			name:             "OneBatch",
   110  		},
   111  		{
   112  			inputBatchSize:   coldata.BatchSize(),
   113  			outputBatchSize:  4,
   114  			blockedThreshold: getDefaultRouterOutputBlockedThreshold(),
   115  			selection:        fullSelection,
   116  			name:             "OneBatchGTOutputSize",
   117  		},
   118  		{
   119  			inputBatchSize:   4,
   120  			outputBatchSize:  coldata.BatchSize(),
   121  			blockedThreshold: getDefaultRouterOutputBlockedThreshold(),
   122  			selection:        fullSelection,
   123  			name:             "MultipleInputBatchesLTOutputSize",
   124  		},
   125  		{
   126  			inputBatchSize:   coldata.BatchSize(),
   127  			outputBatchSize:  coldata.BatchSize(),
   128  			blockedThreshold: getDefaultRouterOutputBlockedThreshold(),
   129  			selection:        fullSelection[:len(fullSelection)/4],
   130  			name:             "QuarterSelection",
   131  		},
   132  	}
   133  
   134  	// unblockEventsChan is purposefully unbuffered; the router output should never write to it
   135  	// in this test.
   136  	unblockEventsChan := make(chan struct{})
   137  
   138  	rng, _ := randutil.NewPseudoRand()
   139  	queueCfg, cleanup, memoryTestCases := getDiskQueueCfgAndMemoryTestCases(t, rng)
   140  	defer cleanup()
   141  
   142  	for _, tc := range testCases {
   143  		if len(tc.selection) == 0 {
   144  			// No data to work with, probably due to a low coldata.BatchSize.
   145  			continue
   146  		}
   147  		for _, mtc := range memoryTestCases {
   148  			t.Run(fmt.Sprintf("%s/memoryLimit=%s", tc.name, humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   149  				// Clear the testAllocator for use.
   150  				testAllocator.ReleaseMemory(testAllocator.Used())
   151  				o := newRouterOutputOpWithBlockedThresholdAndBatchSize(testAllocator, typs, unblockEventsChan, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2), tc.blockedThreshold, tc.outputBatchSize, testDiskAcc)
   152  				in := newOpTestInput(tc.inputBatchSize, data, nil /* typs */)
   153  				out := newOpTestOutput(o, data[:len(tc.selection)])
   154  				in.Init()
   155  				for {
   156  					b := in.Next(ctx)
   157  					o.addBatch(ctx, b, tc.selection)
   158  					if b.Length() == 0 {
   159  						break
   160  					}
   161  				}
   162  				if err := out.Verify(); err != nil {
   163  					t.Fatal(err)
   164  				}
   165  
   166  				// The output should never block. This assumes test cases never send more
   167  				// than defaultRouterOutputBlockedThreshold values.
   168  				select {
   169  				case b := <-unblockEventsChan:
   170  					t.Fatalf("unexpected output state change blocked: %t", b)
   171  				default:
   172  				}
   173  
   174  				if !mtc.skipExpSpillCheck {
   175  					require.Equal(t, mtc.expSpill, o.mu.data.spilled())
   176  				}
   177  			})
   178  		}
   179  	}
   180  }
   181  
   182  func TestRouterOutputNext(t *testing.T) {
   183  	defer leaktest.AfterTest(t)()
   184  	ctx := context.Background()
   185  
   186  	data, typs, fullSelection := getDataAndFullSelection()
   187  
   188  	testCases := []struct {
   189  		unblockEvent func(in colexecbase.Operator, o *routerOutputOp)
   190  		expected     tuples
   191  		name         string
   192  	}{
   193  		{
   194  			// ReaderWaitsForData verifies that a reader blocks in Next(ctx) until there
   195  			// is data available.
   196  			unblockEvent: func(in colexecbase.Operator, o *routerOutputOp) {
   197  				for {
   198  					b := in.Next(ctx)
   199  					o.addBatch(ctx, b, fullSelection)
   200  					if b.Length() == 0 {
   201  						break
   202  					}
   203  				}
   204  			},
   205  			expected: data,
   206  			name:     "ReaderWaitsForData",
   207  		},
   208  		{
   209  			// ReaderWaitsForZeroBatch verifies that a reader blocking on Next will
   210  			// also get unblocked with no data other than the zero batch.
   211  			unblockEvent: func(_ colexecbase.Operator, o *routerOutputOp) {
   212  				o.addBatch(ctx, coldata.ZeroBatch, nil /* selection */)
   213  			},
   214  			expected: tuples{},
   215  			name:     "ReaderWaitsForZeroBatch",
   216  		},
   217  		{
   218  			// CancelUnblocksReader verifies that calling cancel on an output unblocks
   219  			// a reader.
   220  			unblockEvent: func(_ colexecbase.Operator, o *routerOutputOp) {
   221  				o.cancel(ctx)
   222  			},
   223  			expected: tuples{},
   224  			name:     "CancelUnblocksReader",
   225  		},
   226  	}
   227  
   228  	// unblockedEventsChan is purposefully unbuffered; the router output should
   229  	// never write to it in this test.
   230  	unblockedEventsChan := make(chan struct{})
   231  
   232  	rng, _ := randutil.NewPseudoRand()
   233  	queueCfg, cleanup, memoryTestCases := getDiskQueueCfgAndMemoryTestCases(t, rng)
   234  	defer cleanup()
   235  
   236  	for _, mtc := range memoryTestCases {
   237  		for _, tc := range testCases {
   238  			t.Run(fmt.Sprintf("%s/memoryLimit=%s", tc.name, humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   239  				var wg sync.WaitGroup
   240  				batchChan := make(chan coldata.Batch)
   241  				if queueCfg.FS == nil {
   242  					t.Fatal("FS was nil")
   243  				}
   244  				o := newRouterOutputOp(testAllocator, typs, unblockedEventsChan, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2), testDiskAcc)
   245  				in := newOpTestInput(coldata.BatchSize(), data, nil /* typs */)
   246  				in.Init()
   247  				wg.Add(1)
   248  				go func() {
   249  					for {
   250  						b := o.Next(ctx)
   251  						batchChan <- b
   252  						if b.Length() == 0 {
   253  							break
   254  						}
   255  					}
   256  					wg.Done()
   257  				}()
   258  
   259  				// Sleep a long enough amount of time to make sure that if Next didn't block
   260  				// above, we have a good chance of reading a batch.
   261  				time.Sleep(time.Millisecond)
   262  				select {
   263  				case <-batchChan:
   264  					t.Fatal("expected reader goroutine to block when no data ready")
   265  				default:
   266  				}
   267  
   268  				tc.unblockEvent(in, o)
   269  
   270  				// Should have data available, pushed by our reader goroutine.
   271  				batches := colexecbase.NewBatchBuffer()
   272  				out := newOpTestOutput(batches, tc.expected)
   273  				for {
   274  					b := <-batchChan
   275  					batches.Add(b, typs)
   276  					if b.Length() == 0 {
   277  						break
   278  					}
   279  				}
   280  				if err := out.Verify(); err != nil {
   281  					t.Fatal(err)
   282  				}
   283  				wg.Wait()
   284  
   285  				select {
   286  				case <-unblockedEventsChan:
   287  					t.Fatal("unexpected output state change")
   288  				default:
   289  				}
   290  			})
   291  		}
   292  
   293  		t.Run(fmt.Sprintf("NextAfterZeroBatchDoesntBlock/memoryLimit=%s", humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   294  			o := newRouterOutputOp(testAllocator, typs, unblockedEventsChan, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2), testDiskAcc)
   295  			o.addBatch(ctx, coldata.ZeroBatch, fullSelection)
   296  			o.Next(ctx)
   297  			o.Next(ctx)
   298  			select {
   299  			case <-unblockedEventsChan:
   300  				t.Fatal("unexpected output state change")
   301  			default:
   302  			}
   303  		})
   304  
   305  		t.Run(fmt.Sprintf("AddBatchDoesntBlockWhenOutputIsBlocked/memoryLimit=%s", humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   306  			var (
   307  				smallBatchSize = 8
   308  				blockThreshold = smallBatchSize / 2
   309  			)
   310  
   311  			if len(fullSelection) <= smallBatchSize {
   312  				// If a full batch is smaller than our small batch size, reduce it, since
   313  				// this test relies on multiple batches returned from the input.
   314  				smallBatchSize = 2
   315  				if smallBatchSize >= minBatchSize {
   316  					// Sanity check.
   317  					t.Fatalf("smallBatchSize=%d still too large (must be less than minBatchSize=%d)", smallBatchSize, minBatchSize)
   318  				}
   319  				blockThreshold = 1
   320  			}
   321  
   322  			// Use a smaller selection than the batch size; it increases test coverage.
   323  			selection := fullSelection[:blockThreshold]
   324  
   325  			expected := make(tuples, 0, len(data))
   326  			for i := 0; i < len(data); i += smallBatchSize {
   327  				for k := 0; k < blockThreshold && i+k < len(data); k++ {
   328  					expected = append(expected, data[i+k])
   329  				}
   330  			}
   331  
   332  			ch := make(chan struct{}, 2)
   333  			o := newRouterOutputOpWithBlockedThresholdAndBatchSize(testAllocator, typs, ch, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2), blockThreshold, coldata.BatchSize(), testDiskAcc)
   334  			in := newOpTestInput(smallBatchSize, data, nil /* typs */)
   335  			out := newOpTestOutput(o, expected)
   336  			in.Init()
   337  
   338  			b := in.Next(ctx)
   339  			// Make sure the output doesn't consider itself blocked. We're right at the
   340  			// limit but not over.
   341  			if o.addBatch(ctx, b, selection) {
   342  				t.Fatal("unexpectedly blocked")
   343  			}
   344  			b = in.Next(ctx)
   345  			// This addBatch call should now block the output.
   346  			if !o.addBatch(ctx, b, selection) {
   347  				t.Fatal("unexpectedly still unblocked")
   348  			}
   349  
   350  			// Add the rest of the data.
   351  			for {
   352  				b = in.Next(ctx)
   353  				if o.addBatch(ctx, b, selection) {
   354  					t.Fatal("should only return true when switching from unblocked to blocked")
   355  				}
   356  				if b.Length() == 0 {
   357  					break
   358  				}
   359  			}
   360  
   361  			// Unblock the output.
   362  			if err := out.Verify(); err != nil {
   363  				t.Fatal(err)
   364  			}
   365  
   366  			// Verify that an unblock event is sent on the channel. This test will fail
   367  			// with a timeout on a channel read if not.
   368  			<-ch
   369  		})
   370  	}
   371  }
   372  
   373  func TestRouterOutputRandom(t *testing.T) {
   374  	defer leaktest.AfterTest(t)()
   375  	ctx := context.Background()
   376  
   377  	rng, _ := randutil.NewPseudoRand()
   378  
   379  	var (
   380  		maxValues        = coldata.BatchSize() * 4
   381  		blockedThreshold = 1 + rng.Intn(maxValues-1)
   382  		outputSize       = 1 + rng.Intn(maxValues-1)
   383  	)
   384  
   385  	typs := []*types.T{types.Int, types.Int}
   386  
   387  	dataLen := 1 + rng.Intn(maxValues-1)
   388  	data := make(tuples, dataLen)
   389  	for i := range data {
   390  		data[i] = make(tuple, len(typs))
   391  		for j := range typs {
   392  			data[i][j] = rng.Int63()
   393  		}
   394  	}
   395  
   396  	queueCfg, cleanup, memoryTestCases := getDiskQueueCfgAndMemoryTestCases(t, rng)
   397  	defer cleanup()
   398  
   399  	testName := fmt.Sprintf(
   400  		"blockedThreshold=%d/outputSize=%d/totalInputSize=%d", blockedThreshold, outputSize, len(data),
   401  	)
   402  	for _, mtc := range memoryTestCases {
   403  		t.Run(fmt.Sprintf("%s/memoryLimit=%s", testName, humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   404  			runTestsWithFn(t, []tuples{data}, nil /* typs */, func(t *testing.T, inputs []colexecbase.Operator) {
   405  				var wg sync.WaitGroup
   406  				unblockedEventsChans := make(chan struct{}, 2)
   407  				o := newRouterOutputOpWithBlockedThresholdAndBatchSize(testAllocator, typs, unblockedEventsChans, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2), blockedThreshold, outputSize, testDiskAcc)
   408  				inputs[0].Init()
   409  
   410  				expected := make(tuples, 0, len(data))
   411  
   412  				// canceled is a boolean that specifies whether the output was canceled.
   413  				// If this is the case, the output should not be verified.
   414  				canceled := false
   415  
   416  				// Producer.
   417  				errCh := make(chan error)
   418  				wg.Add(1)
   419  				go func() {
   420  					defer wg.Done()
   421  					lastBlockedState := false
   422  					for {
   423  						b := inputs[0].Next(ctx)
   424  						selection := b.Selection()
   425  						if selection == nil {
   426  							selection = coldatatestutils.RandomSel(rng, b.Length(), rng.Float64())
   427  						}
   428  
   429  						selection = selection[:b.Length()]
   430  
   431  						for _, i := range selection {
   432  							expected = append(expected, make(tuple, len(typs)))
   433  							for j := range typs {
   434  								expected[len(expected)-1][j] = b.ColVec(j).Int64()[i]
   435  							}
   436  						}
   437  
   438  						if o.addBatch(ctx, b, selection) {
   439  							if lastBlockedState {
   440  								// We might have missed an unblock event during the last loop.
   441  								select {
   442  								case <-unblockedEventsChans:
   443  								default:
   444  									errCh <- errors.New("output returned state change to blocked when already blocked")
   445  								}
   446  							}
   447  							lastBlockedState = true
   448  						}
   449  
   450  						if rng.Float64() < 0.1 {
   451  							o.cancel(ctx)
   452  							canceled = true
   453  							errCh <- nil
   454  							return
   455  						}
   456  
   457  						// Read any state changes.
   458  						for moreToRead := true; moreToRead; {
   459  							select {
   460  							case <-unblockedEventsChans:
   461  								if !lastBlockedState {
   462  									errCh <- errors.New("received unblocked state change when output is already unblocked")
   463  								}
   464  								lastBlockedState = false
   465  							default:
   466  								moreToRead = false
   467  							}
   468  						}
   469  
   470  						if b.Length() == 0 {
   471  							errCh <- nil
   472  							return
   473  						}
   474  					}
   475  				}()
   476  
   477  				actual := colexecbase.NewBatchBuffer()
   478  
   479  				// Consumer.
   480  				wg.Add(1)
   481  				go func() {
   482  					for {
   483  						b := o.Next(ctx)
   484  						actual.Add(coldatatestutils.CopyBatch(b, typs, testColumnFactory), typs)
   485  						if b.Length() == 0 {
   486  							wg.Done()
   487  							return
   488  						}
   489  					}
   490  				}()
   491  
   492  				if err := <-errCh; err != nil {
   493  					t.Fatal(err)
   494  				}
   495  
   496  				wg.Wait()
   497  
   498  				if canceled {
   499  					return
   500  				}
   501  
   502  				if err := newOpTestOutput(actual, expected).Verify(); err != nil {
   503  					t.Fatal(err)
   504  				}
   505  			})
   506  		})
   507  	}
   508  }
   509  
   510  type callbackRouterOutput struct {
   511  	colexecbase.ZeroInputNode
   512  	addBatchCb func(coldata.Batch, []int) bool
   513  	cancelCb   func()
   514  }
   515  
   516  var _ routerOutput = callbackRouterOutput{}
   517  
   518  func (o callbackRouterOutput) addBatch(
   519  	ctx context.Context, batch coldata.Batch, selection []int,
   520  ) bool {
   521  	if o.addBatchCb != nil {
   522  		return o.addBatchCb(batch, selection)
   523  	}
   524  	return false
   525  }
   526  
   527  func (o callbackRouterOutput) cancel(context.Context) {
   528  	if o.cancelCb != nil {
   529  		o.cancelCb()
   530  	}
   531  }
   532  
   533  func (o callbackRouterOutput) drain() []execinfrapb.ProducerMetadata {
   534  	return nil
   535  }
   536  
   537  func TestHashRouterComputesDestination(t *testing.T) {
   538  	defer leaktest.AfterTest(t)()
   539  	ctx := context.Background()
   540  
   541  	// We have precomputed expectedNumVals only for the default batch size, so we
   542  	// will override it if a different value is set.
   543  	const expectedBatchSize = 1024
   544  	batchSize := coldata.BatchSize()
   545  	if batchSize != expectedBatchSize {
   546  		require.NoError(t, coldata.SetBatchSizeForTests(expectedBatchSize))
   547  		defer func(batchSize int) { require.NoError(t, coldata.SetBatchSizeForTests(batchSize)) }(batchSize)
   548  		batchSize = expectedBatchSize
   549  	}
   550  	data := make(tuples, batchSize)
   551  	valsYetToSee := make(map[int64]struct{})
   552  	for i := range data {
   553  		data[i] = tuple{i}
   554  		valsYetToSee[int64(i)] = struct{}{}
   555  	}
   556  
   557  	in := newOpTestInput(batchSize, data, nil /* typs */)
   558  	in.Init()
   559  
   560  	var (
   561  		// expectedNumVals is the number of expected values the output at the
   562  		// corresponding index in outputs receives. This should not change between
   563  		// runs of tests unless the underlying hash algorithm changes. If it does,
   564  		// distributed hash routing will not produce correct results.
   565  		expectedNumVals = []int{273, 252, 287, 212}
   566  		numOutputs      = 4
   567  		valsPushed      = make([]int, numOutputs)
   568  		typs            = []*types.T{types.Int}
   569  	)
   570  
   571  	outputs := make([]routerOutput, numOutputs)
   572  	for i := range outputs {
   573  		// Capture the index.
   574  		outputIdx := i
   575  		outputs[i] = callbackRouterOutput{
   576  			addBatchCb: func(batch coldata.Batch, sel []int) bool {
   577  				for _, j := range sel {
   578  					key := batch.ColVec(0).Int64()[j]
   579  					if _, ok := valsYetToSee[key]; !ok {
   580  						t.Fatalf("pushed alread seen value to router output: %d", key)
   581  					}
   582  					delete(valsYetToSee, key)
   583  					valsPushed[outputIdx]++
   584  				}
   585  				return false
   586  			},
   587  			cancelCb: func() {
   588  				t.Fatalf(
   589  					"output %d canceled, outputs should not be canceled during normal operation", outputIdx,
   590  				)
   591  			},
   592  		}
   593  	}
   594  
   595  	r := newHashRouterWithOutputs(in, typs, []uint32{0}, nil /* ch */, outputs, nil /* toClose */)
   596  	for r.processNextBatch(ctx) {
   597  	}
   598  
   599  	if len(valsYetToSee) != 0 {
   600  		t.Fatalf("hash router failed to push values: %v", valsYetToSee)
   601  	}
   602  
   603  	for i, expected := range expectedNumVals {
   604  		if valsPushed[i] != expected {
   605  			t.Fatalf("num val slices differ at output %d, expected: %v actual: %v", i, expectedNumVals, valsPushed)
   606  		}
   607  	}
   608  }
   609  
   610  func TestHashRouterCancellation(t *testing.T) {
   611  	defer leaktest.AfterTest(t)()
   612  
   613  	outputs := make([]routerOutput, 4)
   614  	numCancels := int64(0)
   615  	numAddBatches := int64(0)
   616  	for i := range outputs {
   617  		// We'll just be checking canceled.
   618  		outputs[i] = callbackRouterOutput{
   619  			addBatchCb: func(_ coldata.Batch, _ []int) bool {
   620  				atomic.AddInt64(&numAddBatches, 1)
   621  				return false
   622  			},
   623  			cancelCb: func() { atomic.AddInt64(&numCancels, 1) },
   624  		}
   625  	}
   626  
   627  	typs := []*types.T{types.Int}
   628  	// Never-ending input of 0s.
   629  	batch := testAllocator.NewMemBatch(typs)
   630  	batch.SetLength(coldata.BatchSize())
   631  	in := colexecbase.NewRepeatableBatchSource(testAllocator, batch, typs)
   632  
   633  	unbufferedCh := make(chan struct{})
   634  	r := newHashRouterWithOutputs(in, typs, []uint32{0}, unbufferedCh, outputs, nil /* toClose */)
   635  
   636  	t.Run("BeforeRun", func(t *testing.T) {
   637  		ctx, cancel := context.WithCancel(context.Background())
   638  		cancel()
   639  		r.Run(ctx)
   640  
   641  		if numCancels != int64(len(outputs)) {
   642  			t.Fatalf("expected %d canceled outputs, actual %d", len(outputs), numCancels)
   643  		}
   644  
   645  		if numAddBatches != 0 {
   646  			t.Fatalf("detected %d addBatch calls but expected 0", numAddBatches)
   647  		}
   648  
   649  		meta := r.DrainMeta(ctx)
   650  		require.Equal(t, 1, len(meta))
   651  		require.True(t, testutils.IsError(meta[0].Err, "context canceled"), meta[0].Err)
   652  	})
   653  
   654  	testCases := []struct {
   655  		blocked bool
   656  		name    string
   657  	}{
   658  		{
   659  			blocked: false,
   660  			name:    "DuringRun",
   661  		},
   662  		{
   663  			blocked: true,
   664  			name:    "WhileWaitingForUnblock",
   665  		},
   666  	}
   667  
   668  	for _, tc := range testCases {
   669  		t.Run(tc.name, func(t *testing.T) {
   670  			numCancels = 0
   671  			numAddBatches = 0
   672  
   673  			ctx, cancel := context.WithCancel(context.Background())
   674  
   675  			if tc.blocked {
   676  				r.numBlockedOutputs = len(outputs)
   677  				defer func() {
   678  					r.numBlockedOutputs = 0
   679  				}()
   680  			}
   681  
   682  			routerMeta := make(chan []execinfrapb.ProducerMetadata)
   683  			go func() {
   684  				r.Run(ctx)
   685  				routerMeta <- r.DrainMeta(ctx)
   686  				close(routerMeta)
   687  			}()
   688  
   689  			time.Sleep(time.Millisecond)
   690  			if tc.blocked {
   691  				// Make sure no addBatches happened.
   692  				if n := atomic.LoadInt64(&numAddBatches); n != 0 {
   693  					t.Fatalf("expected router to be blocked, but detected %d addBatch calls", n)
   694  				}
   695  			}
   696  			select {
   697  			case <-routerMeta:
   698  				t.Fatal("hash router goroutine unexpectedly done")
   699  			default:
   700  			}
   701  			cancel()
   702  			meta := <-routerMeta
   703  			require.Equal(t, 1, len(meta))
   704  			require.True(t, testutils.IsError(meta[0].Err, "canceled"), meta[0].Err)
   705  
   706  			if numCancels != int64(len(outputs)) {
   707  				t.Fatalf("expected %d canceled outputs, actual %d", len(outputs), numCancels)
   708  			}
   709  		})
   710  	}
   711  }
   712  
   713  func TestHashRouterOneOutput(t *testing.T) {
   714  	defer leaktest.AfterTest(t)()
   715  	ctx := context.Background()
   716  
   717  	rng, _ := randutil.NewPseudoRand()
   718  
   719  	sel := coldatatestutils.RandomSel(rng, coldata.BatchSize(), rng.Float64())
   720  
   721  	data, typs, _ := getDataAndFullSelection()
   722  
   723  	expected := make(tuples, 0, len(data))
   724  	for _, i := range sel {
   725  		expected = append(expected, data[i])
   726  	}
   727  
   728  	queueCfg, cleanup, memoryTestCases := getDiskQueueCfgAndMemoryTestCases(t, rng)
   729  	defer cleanup()
   730  
   731  	for _, mtc := range memoryTestCases {
   732  		t.Run(fmt.Sprintf("memoryLimit=%s", humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   733  			// Clear the testAllocator for use.
   734  			testAllocator.ReleaseMemory(testAllocator.Used())
   735  			diskAcc := testDiskMonitor.MakeBoundAccount()
   736  			defer diskAcc.Close(ctx)
   737  			r, routerOutputs := NewHashRouter(
   738  				[]*colmem.Allocator{testAllocator}, newOpFixedSelTestInput(sel, len(sel), data, typs),
   739  				typs, []uint32{0}, mtc.bytes, queueCfg, colexecbase.NewTestingSemaphore(2),
   740  				[]*mon.BoundAccount{&diskAcc}, nil, /* toClose */
   741  			)
   742  
   743  			if len(routerOutputs) != 1 {
   744  				t.Fatalf("expected 1 router output but got %d", len(routerOutputs))
   745  			}
   746  
   747  			o := newOpTestOutput(routerOutputs[0], expected)
   748  
   749  			ro := routerOutputs[0].(*routerOutputOp)
   750  			// Set alwaysFlush so that data is always flushed to the spillingQueue.
   751  			ro.testingKnobs.alwaysFlush = true
   752  
   753  			var wg sync.WaitGroup
   754  			wg.Add(1)
   755  			go func() {
   756  				r.Run(ctx)
   757  				wg.Done()
   758  			}()
   759  
   760  			if err := o.Verify(); err != nil {
   761  				t.Fatal(err)
   762  			}
   763  			wg.Wait()
   764  			// Expect no metadata, this should be a successful run.
   765  			unexpectedMetadata := r.DrainMeta(ctx)
   766  			if len(unexpectedMetadata) != 0 {
   767  				t.Fatalf("unexpected metadata when draining HashRouter: %+v", unexpectedMetadata)
   768  			}
   769  			if !mtc.skipExpSpillCheck {
   770  				// If len(sel) == 0, no items will have been enqueued so override an
   771  				// expected spill if this is the case.
   772  				mtc.expSpill = mtc.expSpill && len(sel) != 0
   773  				require.Equal(t, mtc.expSpill, ro.mu.data.spilled())
   774  			}
   775  		})
   776  	}
   777  }
   778  
   779  func TestHashRouterRandom(t *testing.T) {
   780  	defer leaktest.AfterTest(t)()
   781  	ctx := context.Background()
   782  
   783  	rng, _ := randutil.NewPseudoRand()
   784  
   785  	var (
   786  		maxValues        = coldata.BatchSize() * 4
   787  		maxOutputs       = 128
   788  		blockedThreshold = 1 + rng.Intn(maxValues-1)
   789  		outputSize       = 1 + rng.Intn(maxValues-1)
   790  		numOutputs       = 1 + rng.Intn(maxOutputs-1)
   791  	)
   792  
   793  	typs := []*types.T{types.Int, types.Int}
   794  	dataLen := 1 + rng.Intn(maxValues-1)
   795  	data := make(tuples, dataLen)
   796  	for i := range data {
   797  		data[i] = make(tuple, len(typs))
   798  		for j := range typs {
   799  			data[i][j] = rng.Int63()
   800  		}
   801  	}
   802  
   803  	hashCols := make([]uint32, 0, len(typs))
   804  	hashCols = append(hashCols, 0)
   805  	for i := 1; i < cap(hashCols); i++ {
   806  		if rng.Float64() < 0.5 {
   807  			hashCols = append(hashCols, uint32(i))
   808  		}
   809  	}
   810  
   811  	// cancel determines whether we test cancellation.
   812  	cancel := false
   813  	if rng.Float64() < 0.25 {
   814  		cancel = true
   815  	}
   816  
   817  	testName := fmt.Sprintf(
   818  		"numOutputs=%d/blockedThreshold=%d/outputSize=%d/totalInputSize=%d/hashCols=%v/cancel=%t",
   819  		numOutputs,
   820  		blockedThreshold,
   821  		outputSize,
   822  		len(data),
   823  		hashCols,
   824  		cancel,
   825  	)
   826  
   827  	queueCfg, cleanup, memoryTestCases := getDiskQueueCfgAndMemoryTestCases(t, rng)
   828  	defer cleanup()
   829  
   830  	// expectedDistribution is set after the first run and used to verify that the
   831  	// distribution of results does not change between runs, as we are sending the
   832  	// same data to the same number of outputs.
   833  	var expectedDistribution []int
   834  	for _, mtc := range memoryTestCases {
   835  		t.Run(fmt.Sprintf(testName+"/memoryLimit=%s", humanizeutil.IBytes(mtc.bytes)), func(t *testing.T) {
   836  			runTestsWithFn(t, []tuples{data}, nil /* typs */, func(t *testing.T, inputs []colexecbase.Operator) {
   837  				unblockEventsChan := make(chan struct{}, 2*numOutputs)
   838  				outputs := make([]routerOutput, numOutputs)
   839  				outputsAsOps := make([]colexecbase.Operator, numOutputs)
   840  				memoryLimitPerOutput := mtc.bytes / int64(len(outputs))
   841  				for i := range outputs {
   842  					// Create separate monitoring infrastructure as well as
   843  					// an allocator for each output as router outputs run
   844  					// concurrently.
   845  					acc := testMemMonitor.MakeBoundAccount()
   846  					defer acc.Close(ctx)
   847  					diskAcc := testDiskMonitor.MakeBoundAccount()
   848  					defer diskAcc.Close(ctx)
   849  					allocator := colmem.NewAllocator(ctx, &acc, testColumnFactory)
   850  					op := newRouterOutputOpWithBlockedThresholdAndBatchSize(allocator, typs, unblockEventsChan, memoryLimitPerOutput, queueCfg, colexecbase.NewTestingSemaphore(len(outputs)*2), blockedThreshold, outputSize, &diskAcc)
   851  					outputs[i] = op
   852  					outputsAsOps[i] = op
   853  				}
   854  
   855  				r := newHashRouterWithOutputs(
   856  					inputs[0], typs, hashCols, unblockEventsChan, outputs, nil, /* toClose */
   857  				)
   858  
   859  				var (
   860  					results uint64
   861  					wg      sync.WaitGroup
   862  				)
   863  				resultsByOp := make([]int, len(outputsAsOps))
   864  				wg.Add(len(outputsAsOps))
   865  				for i := range outputsAsOps {
   866  					go func(i int) {
   867  						for {
   868  							b := outputsAsOps[i].Next(ctx)
   869  							if b.Length() == 0 {
   870  								break
   871  							}
   872  							atomic.AddUint64(&results, uint64(b.Length()))
   873  							resultsByOp[i] += b.Length()
   874  						}
   875  						wg.Done()
   876  					}(i)
   877  				}
   878  
   879  				ctx, cancelFunc := context.WithCancel(context.Background())
   880  				wg.Add(1)
   881  				go func() {
   882  					r.Run(ctx)
   883  					wg.Done()
   884  				}()
   885  
   886  				if cancel {
   887  					// Sleep between 0 and ~5 milliseconds.
   888  					time.Sleep(time.Microsecond * time.Duration(rng.Intn(5000)))
   889  					cancelFunc()
   890  				} else {
   891  					// Satisfy linter context leak error.
   892  					defer cancelFunc()
   893  				}
   894  
   895  				// Ensure all goroutines end. If a test fails with a hang here it is most
   896  				// likely due to a cancellation bug.
   897  				wg.Wait()
   898  				if !cancel {
   899  					// Expect no metadata, this should be a successful run.
   900  					unexpectedMetadata := r.DrainMeta(ctx)
   901  					if len(unexpectedMetadata) != 0 {
   902  						t.Fatalf("unexpected metadata when draining HashRouter: %+v", unexpectedMetadata)
   903  					}
   904  					// Only do output verification if no cancellation happened.
   905  					if actualTotal := atomic.LoadUint64(&results); actualTotal != uint64(len(data)) {
   906  						t.Fatalf("unexpected number of results %d, expected %d", actualTotal, len(data))
   907  					}
   908  					if expectedDistribution == nil {
   909  						expectedDistribution = resultsByOp
   910  						return
   911  					}
   912  					for i, numVals := range expectedDistribution {
   913  						if numVals != resultsByOp[i] {
   914  							t.Fatalf(
   915  								"distribution of results changed compared to first run at output %d. expected: %v, actual: %v",
   916  								i,
   917  								expectedDistribution,
   918  								resultsByOp,
   919  							)
   920  						}
   921  					}
   922  				}
   923  			})
   924  		})
   925  	}
   926  }
   927  
   928  func BenchmarkHashRouter(b *testing.B) {
   929  	defer leaktest.AfterTest(b)()
   930  	ctx := context.Background()
   931  
   932  	// Use only one type. Note: the more types you use, the more you inflate the
   933  	// numbers.
   934  	typs := []*types.T{types.Int}
   935  	batch := testAllocator.NewMemBatch(typs)
   936  	batch.SetLength(coldata.BatchSize())
   937  	input := colexecbase.NewRepeatableBatchSource(testAllocator, batch, typs)
   938  
   939  	queueCfg, cleanup := colcontainerutils.NewTestingDiskQueueCfg(b, true /* inMem */)
   940  	defer cleanup()
   941  
   942  	var wg sync.WaitGroup
   943  	for _, numOutputs := range []int{2, 4, 8, 16} {
   944  		for _, numInputBatches := range []int{2, 4, 8, 16} {
   945  			b.Run(fmt.Sprintf("numOutputs=%d/numInputBatches=%d", numOutputs, numInputBatches), func(b *testing.B) {
   946  				allocators := make([]*colmem.Allocator, numOutputs)
   947  				diskAccounts := make([]*mon.BoundAccount, numOutputs)
   948  				for i := range allocators {
   949  					acc := testMemMonitor.MakeBoundAccount()
   950  					allocators[i] = colmem.NewAllocator(ctx, &acc, testColumnFactory)
   951  					defer acc.Close(ctx)
   952  					diskAcc := testDiskMonitor.MakeBoundAccount()
   953  					diskAccounts[i] = &diskAcc
   954  					defer diskAcc.Close(ctx)
   955  				}
   956  				r, outputs := NewHashRouter(
   957  					allocators, input, typs, []uint32{0}, 64<<20,
   958  					queueCfg, &colexecbase.TestingSemaphore{}, diskAccounts, nil, /* toClose */
   959  				)
   960  				b.SetBytes(8 * int64(coldata.BatchSize()) * int64(numInputBatches))
   961  				// We expect distribution to not change. This is a sanity check that
   962  				// we're resetting properly.
   963  				var expectedDistribution []int
   964  				actualDistribution := make([]int, len(outputs))
   965  				// zeroDistribution just allows us to reset actualDistribution with a
   966  				// copy.
   967  				zeroDistribution := make([]int, len(outputs))
   968  				b.ResetTimer()
   969  				for i := 0; i < b.N; i++ {
   970  					input.ResetBatchesToReturn(numInputBatches)
   971  					r.resetForBenchmarks(ctx)
   972  					wg.Add(len(outputs))
   973  					for j := range outputs {
   974  						go func(j int) {
   975  							for {
   976  								oBatch := outputs[j].Next(ctx)
   977  								actualDistribution[j] += oBatch.Length()
   978  								if oBatch.Length() == 0 {
   979  									break
   980  								}
   981  							}
   982  							wg.Done()
   983  						}(j)
   984  					}
   985  					r.Run(ctx)
   986  					wg.Wait()
   987  					// sum sanity checks that we are actually pushing as many values as we
   988  					// expect.
   989  					sum := 0
   990  					for i := range actualDistribution {
   991  						sum += actualDistribution[i]
   992  					}
   993  					if sum != numInputBatches*coldata.BatchSize() {
   994  						b.Fatalf("unexpected sum %d, expected %d", sum, numInputBatches*coldata.BatchSize())
   995  					}
   996  					if expectedDistribution == nil {
   997  						expectedDistribution = make([]int, len(actualDistribution))
   998  						copy(expectedDistribution, actualDistribution)
   999  					} else {
  1000  						for j := range expectedDistribution {
  1001  							if expectedDistribution[j] != actualDistribution[j] {
  1002  								b.Fatalf(
  1003  									"not resetting properly expected distribution: %v, actual distribution: %v",
  1004  									expectedDistribution,
  1005  									actualDistribution,
  1006  								)
  1007  							}
  1008  						}
  1009  					}
  1010  					copy(actualDistribution, zeroDistribution)
  1011  				}
  1012  			})
  1013  		}
  1014  	}
  1015  }