github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/physicalplan/aggregator_funcs_test.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package physicalplan
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"math"
    17  	"math/big"
    18  	"testing"
    19  
    20  	"github.com/cockroachdb/cockroach/pkg/base"
    21  	"github.com/cockroachdb/cockroach/pkg/keys"
    22  	"github.com/cockroachdb/cockroach/pkg/kv"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/distsql"
    24  	"github.com/cockroachdb/cockroach/pkg/sql/execinfra"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    29  	"github.com/cockroachdb/cockroach/pkg/testutils/distsqlutils"
    30  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    31  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    32  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    33  	"github.com/cockroachdb/cockroach/pkg/util/randutil"
    34  	"github.com/cockroachdb/cockroach/pkg/util/uuid"
    35  )
    36  
    37  var (
    38  	// diffCtx is a decimal context used to perform subtractions between
    39  	// local and non-local decimal results to check if they are within
    40  	// 1ulp. Decimals within 1ulp is acceptable for high-precision
    41  	// decimal calculations.
    42  	diffCtx = tree.DecimalCtx.WithPrecision(0)
    43  	// Use to check for 1ulp.
    44  	bigOne = big.NewInt(1)
    45  	// floatPrecFmt is the format string with a precision of 3 (after
    46  	// decimal point) specified for float comparisons. Float aggregation
    47  	// operations involve unavoidable off-by-last-few-digits errors, which
    48  	// is expected.
    49  	floatPrecFmt = "%.3f"
    50  )
    51  
    52  // runTestFlow runs a flow with the given processors and returns the results.
    53  // Any errors stop the current test.
    54  func runTestFlow(
    55  	t *testing.T,
    56  	srv serverutils.TestServerInterface,
    57  	txn *kv.Txn,
    58  	procs ...execinfrapb.ProcessorSpec,
    59  ) sqlbase.EncDatumRows {
    60  	distSQLSrv := srv.DistSQLServer().(*distsql.ServerImpl)
    61  
    62  	leafInputState := txn.GetLeafTxnInputState(context.Background())
    63  	req := execinfrapb.SetupFlowRequest{
    64  		Version:           execinfra.Version,
    65  		LeafTxnInputState: &leafInputState,
    66  		Flow: execinfrapb.FlowSpec{
    67  			FlowID:     execinfrapb.FlowID{UUID: uuid.MakeV4()},
    68  			Processors: procs,
    69  		},
    70  	}
    71  
    72  	var rowBuf distsqlutils.RowBuffer
    73  
    74  	ctx, flow, err := distSQLSrv.SetupSyncFlow(context.Background(), distSQLSrv.ParentMemoryMonitor, &req, &rowBuf)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	if err := flow.Start(ctx, func() {}); err != nil {
    79  		t.Fatal(err)
    80  	}
    81  	flow.Wait()
    82  	flow.Cleanup(ctx)
    83  
    84  	if !rowBuf.ProducerClosed() {
    85  		t.Errorf("output not closed")
    86  	}
    87  
    88  	var res sqlbase.EncDatumRows
    89  	for {
    90  		row, meta := rowBuf.Next()
    91  		if meta != nil {
    92  			if meta.LeafTxnFinalState != nil || meta.Metrics != nil {
    93  				continue
    94  			}
    95  			t.Fatalf("unexpected metadata: %v", meta)
    96  		}
    97  		if row == nil {
    98  			break
    99  		}
   100  		res = append(res, row)
   101  	}
   102  
   103  	return res
   104  }
   105  
   106  // checkDistAggregationInfo tests that a flow with multiple local stages and a
   107  // final stage (in accordance with per DistAggregationInfo) gets the same result
   108  // with a naive aggregation flow that has a single non-distributed stage.
   109  //
   110  // Both types of flows are set up and ran against the first numRows of the given
   111  // table. We assume the table's first column is the primary key, with values
   112  // from 1 to numRows. A non-PK column that works with the function is chosen.
   113  func checkDistAggregationInfo(
   114  	ctx context.Context,
   115  	t *testing.T,
   116  	srv serverutils.TestServerInterface,
   117  	tableDesc *sqlbase.TableDescriptor,
   118  	colIdx int,
   119  	numRows int,
   120  	fn execinfrapb.AggregatorSpec_Func,
   121  	info DistAggregationInfo,
   122  ) {
   123  	colType := tableDesc.Columns[colIdx].Type
   124  
   125  	makeTableReader := func(startPK, endPK int, streamID int) execinfrapb.ProcessorSpec {
   126  		tr := execinfrapb.TableReaderSpec{
   127  			Table: *tableDesc,
   128  			Spans: make([]execinfrapb.TableReaderSpan, 1),
   129  		}
   130  
   131  		var err error
   132  		tr.Spans[0].Span.Key, err = sqlbase.TestingMakePrimaryIndexKey(tableDesc, startPK)
   133  		if err != nil {
   134  			t.Fatal(err)
   135  		}
   136  		tr.Spans[0].Span.EndKey, err = sqlbase.TestingMakePrimaryIndexKey(tableDesc, endPK)
   137  		if err != nil {
   138  			t.Fatal(err)
   139  		}
   140  
   141  		return execinfrapb.ProcessorSpec{
   142  			Core: execinfrapb.ProcessorCoreUnion{TableReader: &tr},
   143  			Post: execinfrapb.PostProcessSpec{
   144  				Projection:    true,
   145  				OutputColumns: []uint32{uint32(colIdx)},
   146  			},
   147  			Output: []execinfrapb.OutputRouterSpec{{
   148  				Type: execinfrapb.OutputRouterSpec_PASS_THROUGH,
   149  				Streams: []execinfrapb.StreamEndpointSpec{
   150  					{Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(streamID)},
   151  				},
   152  			}},
   153  		}
   154  	}
   155  
   156  	txn := kv.NewTxn(ctx, srv.DB(), srv.NodeID())
   157  
   158  	// First run a flow that aggregates all the rows without any local stages.
   159  
   160  	rowsNonDist := runTestFlow(
   161  		t, srv, txn,
   162  		makeTableReader(1, numRows+1, 0),
   163  		execinfrapb.ProcessorSpec{
   164  			Input: []execinfrapb.InputSyncSpec{{
   165  				Type:        execinfrapb.InputSyncSpec_UNORDERED,
   166  				ColumnTypes: []*types.T{colType},
   167  				Streams: []execinfrapb.StreamEndpointSpec{
   168  					{Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: 0},
   169  				},
   170  			}},
   171  			Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{
   172  				Aggregations: []execinfrapb.AggregatorSpec_Aggregation{{Func: fn, ColIdx: []uint32{0}}},
   173  			}},
   174  			Output: []execinfrapb.OutputRouterSpec{{
   175  				Type: execinfrapb.OutputRouterSpec_PASS_THROUGH,
   176  				Streams: []execinfrapb.StreamEndpointSpec{
   177  					{Type: execinfrapb.StreamEndpointSpec_SYNC_RESPONSE},
   178  				},
   179  			}},
   180  		},
   181  	)
   182  
   183  	numIntermediary := len(info.LocalStage)
   184  	numFinal := len(info.FinalStage)
   185  	for _, finalInfo := range info.FinalStage {
   186  		if len(finalInfo.LocalIdxs) == 0 {
   187  			t.Fatalf("final stage must specify input local indices: %#v", info)
   188  		}
   189  		for _, localIdx := range finalInfo.LocalIdxs {
   190  			if localIdx >= uint32(numIntermediary) {
   191  				t.Fatalf("local index %d out of bounds of local stages: %#v", localIdx, info)
   192  			}
   193  		}
   194  	}
   195  
   196  	// Now run a flow with 4 separate table readers, each with its own local
   197  	// stage, all feeding into a single final stage.
   198  
   199  	numParallel := 4
   200  
   201  	// The type(s) outputted by the local stage can be different than the input type
   202  	// (e.g. DECIMAL instead of INT).
   203  	intermediaryTypes := make([]*types.T, numIntermediary)
   204  	for i, fn := range info.LocalStage {
   205  		var err error
   206  		_, returnTyp, err := execinfrapb.GetAggregateInfo(fn, colType)
   207  		if err != nil {
   208  			t.Fatal(err)
   209  		}
   210  		intermediaryTypes[i] = returnTyp
   211  	}
   212  
   213  	localAggregations := make([]execinfrapb.AggregatorSpec_Aggregation, numIntermediary)
   214  	for i, fn := range info.LocalStage {
   215  		// Local aggregations have the same input.
   216  		localAggregations[i] = execinfrapb.AggregatorSpec_Aggregation{Func: fn, ColIdx: []uint32{0}}
   217  	}
   218  	finalAggregations := make([]execinfrapb.AggregatorSpec_Aggregation, numFinal)
   219  	for i, finalInfo := range info.FinalStage {
   220  		// Each local aggregation feeds into a final aggregation.
   221  		finalAggregations[i] = execinfrapb.AggregatorSpec_Aggregation{
   222  			Func:   finalInfo.Fn,
   223  			ColIdx: finalInfo.LocalIdxs,
   224  		}
   225  	}
   226  
   227  	if numParallel < numRows {
   228  		numParallel = numRows
   229  	}
   230  	finalProc := execinfrapb.ProcessorSpec{
   231  		Input: []execinfrapb.InputSyncSpec{{
   232  			Type:        execinfrapb.InputSyncSpec_UNORDERED,
   233  			ColumnTypes: intermediaryTypes,
   234  		}},
   235  		Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{
   236  			Aggregations: finalAggregations,
   237  		}},
   238  		Output: []execinfrapb.OutputRouterSpec{{
   239  			Type: execinfrapb.OutputRouterSpec_PASS_THROUGH,
   240  			Streams: []execinfrapb.StreamEndpointSpec{
   241  				{Type: execinfrapb.StreamEndpointSpec_SYNC_RESPONSE},
   242  			},
   243  		}},
   244  	}
   245  
   246  	// The type(s) outputted by the final stage can be different than the
   247  	// input type (e.g. DECIMAL instead of INT).
   248  	finalOutputTypes := make([]*types.T, numFinal)
   249  	// Passed into FinalIndexing as the indices for the IndexedVars inputs
   250  	// to the post processor.
   251  	varIdxs := make([]int, numFinal)
   252  	for i, finalInfo := range info.FinalStage {
   253  		inputTypes := make([]*types.T, len(finalInfo.LocalIdxs))
   254  		for i, localIdx := range finalInfo.LocalIdxs {
   255  			inputTypes[i] = intermediaryTypes[localIdx]
   256  		}
   257  		var err error
   258  		_, finalOutputTypes[i], err = execinfrapb.GetAggregateInfo(finalInfo.Fn, inputTypes...)
   259  		if err != nil {
   260  			t.Fatal(err)
   261  		}
   262  		varIdxs[i] = i
   263  	}
   264  
   265  	var procs []execinfrapb.ProcessorSpec
   266  	for i := 0; i < numParallel; i++ {
   267  		tr := makeTableReader(1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel, 2*i)
   268  		agg := execinfrapb.ProcessorSpec{
   269  			Input: []execinfrapb.InputSyncSpec{{
   270  				Type:        execinfrapb.InputSyncSpec_UNORDERED,
   271  				ColumnTypes: []*types.T{colType},
   272  				Streams: []execinfrapb.StreamEndpointSpec{
   273  					{Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(2 * i)},
   274  				},
   275  			}},
   276  			Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{
   277  				Aggregations: localAggregations,
   278  			}},
   279  			Output: []execinfrapb.OutputRouterSpec{{
   280  				Type: execinfrapb.OutputRouterSpec_PASS_THROUGH,
   281  				Streams: []execinfrapb.StreamEndpointSpec{
   282  					{Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(2*i + 1)},
   283  				},
   284  			}},
   285  		}
   286  		procs = append(procs, tr, agg)
   287  		finalProc.Input[0].Streams = append(finalProc.Input[0].Streams, execinfrapb.StreamEndpointSpec{
   288  			Type:     execinfrapb.StreamEndpointSpec_LOCAL,
   289  			StreamID: execinfrapb.StreamID(2*i + 1),
   290  		})
   291  	}
   292  
   293  	if info.FinalRendering != nil {
   294  		h := tree.MakeTypesOnlyIndexedVarHelper(finalOutputTypes)
   295  		renderExpr, err := info.FinalRendering(&h, varIdxs)
   296  		if err != nil {
   297  			t.Fatal(err)
   298  		}
   299  		var expr execinfrapb.Expression
   300  		expr, err = MakeExpression(renderExpr, nil, nil)
   301  		if err != nil {
   302  			t.Fatal(err)
   303  		}
   304  		finalProc.Post.RenderExprs = []execinfrapb.Expression{expr}
   305  
   306  	}
   307  
   308  	procs = append(procs, finalProc)
   309  	rowsDist := runTestFlow(t, srv, txn, procs...)
   310  
   311  	if len(rowsDist[0]) != len(rowsNonDist[0]) {
   312  		t.Errorf("different row lengths (dist: %d non-dist: %d)", len(rowsDist[0]), len(rowsNonDist[0]))
   313  	} else {
   314  		for i := range rowsDist[0] {
   315  			rowDist := rowsDist[0][i]
   316  			rowNonDist := rowsNonDist[0][i]
   317  			if rowDist.Datum.ResolvedType().Family() != rowNonDist.Datum.ResolvedType().Family() {
   318  				t.Fatalf("different type for column %d (dist: %s non-dist: %s)", i, rowDist.Datum.ResolvedType(), rowNonDist.Datum.ResolvedType())
   319  			}
   320  
   321  			var equiv bool
   322  			var strDist, strNonDist string
   323  			switch typedDist := rowDist.Datum.(type) {
   324  			case *tree.DDecimal:
   325  				// For some decimal operations, non-local and
   326  				// local computations may differ by the last
   327  				// digit (by 1 ulp).
   328  				decDist := &typedDist.Decimal
   329  				decNonDist := &rowNonDist.Datum.(*tree.DDecimal).Decimal
   330  				strDist = decDist.String()
   331  				strNonDist = decNonDist.String()
   332  				// We first check if they're equivalent, and if
   333  				// not, we check if they're within 1ulp.
   334  				equiv = decDist.Cmp(decNonDist) == 0
   335  				if !equiv {
   336  					if _, err := diffCtx.Sub(decNonDist, decNonDist, decDist); err != nil {
   337  						t.Fatal(err)
   338  					}
   339  					equiv = decNonDist.Coeff.Cmp(bigOne) == 0
   340  				}
   341  			case *tree.DFloat:
   342  				// Float results are highly variable and loss
   343  				// of precision between non-local and local is
   344  				// expected. We reduce the precision specified
   345  				// by floatPrecFmt and compare their string
   346  				// representations.
   347  				floatDist := float64(*typedDist)
   348  				floatNonDist := float64(*rowNonDist.Datum.(*tree.DFloat))
   349  				strDist = fmt.Sprintf(floatPrecFmt, floatDist)
   350  				strNonDist = fmt.Sprintf(floatPrecFmt, floatNonDist)
   351  				// Compare using a relative equality
   352  				// func that isn't dependent on the scale
   353  				// of the number. In addition, ignore any
   354  				// NaNs. Sometimes due to the non-deterministic
   355  				// ordering of distsql, we get a +Inf and
   356  				// Nan result. Both of these changes started
   357  				// happening with the float rand datums
   358  				// were taught about some more adversarial
   359  				// inputs. Since floats by nature have equality
   360  				// problems and I think our algorithms are
   361  				// correct, we need to be slightly more lenient
   362  				// in our float comparisons.
   363  				equiv = almostEqualRelative(floatDist, floatNonDist) || math.IsNaN(floatNonDist) || math.IsNaN(floatDist)
   364  			default:
   365  				// For all other types, a simple string
   366  				// representation comparison will suffice.
   367  				strDist = rowDist.Datum.String()
   368  				strNonDist = rowNonDist.Datum.String()
   369  				equiv = strDist == strNonDist
   370  			}
   371  			if !equiv {
   372  				t.Errorf("different results for column %d\nw/o local stage:   %s\nwith local stage:  %s", i, strDist, strNonDist)
   373  			}
   374  		}
   375  	}
   376  }
   377  
   378  // almostEqualRelative returns whether a and b are close-enough to equal. It
   379  // checks if the two numbers are within a certain relative percentage of
   380  // each other (maxRelDiff), which avoids problems when using "%.3f" as a
   381  // comparison string. This is the "Relative epsilon comparisons" method from:
   382  // https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
   383  func almostEqualRelative(a, b float64) bool {
   384  	if a == b {
   385  		return true
   386  	}
   387  	// Calculate the difference.
   388  	diff := math.Abs(a - b)
   389  	A := math.Abs(a)
   390  	B := math.Abs(b)
   391  	// Find the largest
   392  	largest := A
   393  	if B > A {
   394  		largest = B
   395  	}
   396  	const maxRelDiff = 1e-10
   397  	return diff <= largest*maxRelDiff
   398  }
   399  
   400  // Test that distributing agg functions according to DistAggregationTable
   401  // yields correct results. We're going to run each aggregation as either the
   402  // two-stage process described by the DistAggregationTable or as a single global
   403  // process, and verify that the results are the same.
   404  func TestDistAggregationTable(t *testing.T) {
   405  	defer leaktest.AfterTest(t)()
   406  	const numRows = 100
   407  
   408  	tc := serverutils.StartTestCluster(t, 1, base.TestClusterArgs{})
   409  	defer tc.Stopper().Stop(context.Background())
   410  
   411  	// Create a table with a few columns:
   412  	//  - random integer values from 0 to numRows
   413  	//  - random integer values (with some NULLs)
   414  	//  - random bool value (mostly false)
   415  	//  - random bool value (mostly true)
   416  	//  - random decimals
   417  	//  - random decimals (with some NULLs)
   418  	rng, _ := randutil.NewPseudoRand()
   419  	sqlutils.CreateTable(
   420  		t, tc.ServerConn(0), "t",
   421  		"k INT PRIMARY KEY, int1 INT, int2 INT, bool1 BOOL, bool2 BOOL, dec1 DECIMAL, dec2 DECIMAL, float1 FLOAT, float2 FLOAT, b BYTES",
   422  		numRows,
   423  		func(row int) []tree.Datum {
   424  			return []tree.Datum{
   425  				tree.NewDInt(tree.DInt(row)),
   426  				tree.NewDInt(tree.DInt(rng.Intn(numRows))),
   427  				sqlbase.RandDatum(rng, types.Int, true),
   428  				tree.MakeDBool(tree.DBool(rng.Intn(10) == 0)),
   429  				tree.MakeDBool(tree.DBool(rng.Intn(10) != 0)),
   430  				sqlbase.RandDatum(rng, types.Decimal, false),
   431  				sqlbase.RandDatum(rng, types.Decimal, true),
   432  				sqlbase.RandDatum(rng, types.Float, false),
   433  				sqlbase.RandDatum(rng, types.Float, true),
   434  				tree.NewDBytes(tree.DBytes(randutil.RandBytes(rng, 10))),
   435  			}
   436  		},
   437  	)
   438  
   439  	kvDB := tc.Server(0).DB()
   440  	desc := sqlbase.GetTableDescriptor(kvDB, keys.SystemSQLCodec, "test", "t")
   441  
   442  	for fn, info := range DistAggregationTable {
   443  		if fn == execinfrapb.AggregatorSpec_ANY_NOT_NULL {
   444  			// ANY_NOT_NULL only has a definite result if all rows have the same value
   445  			// on the relevant column; skip testing this trivial case.
   446  			continue
   447  		}
   448  		if fn == execinfrapb.AggregatorSpec_COUNT_ROWS {
   449  			// COUNT_ROWS takes no arguments; skip it in this test.
   450  			continue
   451  		}
   452  		// We're going to test each aggregation function on every column that can be
   453  		// used as input for it.
   454  		foundCol := false
   455  		for colIdx := 1; colIdx < len(desc.Columns); colIdx++ {
   456  			// See if this column works with this function.
   457  			_, _, err := execinfrapb.GetAggregateInfo(fn, desc.Columns[colIdx].Type)
   458  			if err != nil {
   459  				continue
   460  			}
   461  			foundCol = true
   462  			for _, numRows := range []int{5, numRows / 10, numRows / 2, numRows} {
   463  				name := fmt.Sprintf("%s/%s/%d", fn, desc.Columns[colIdx].Name, numRows)
   464  				t.Run(name, func(t *testing.T) {
   465  					checkDistAggregationInfo(
   466  						context.Background(), t, tc.Server(0), desc, colIdx, numRows, fn, info)
   467  				})
   468  			}
   469  		}
   470  		if !foundCol {
   471  			t.Errorf("aggregation function %s was not tested (no suitable column)", fn)
   472  		}
   473  	}
   474  }