github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/physicalplan/aggregator_funcs_test.go (about) 1 // Copyright 2016 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package physicalplan 12 13 import ( 14 "context" 15 "fmt" 16 "math" 17 "math/big" 18 "testing" 19 20 "github.com/cockroachdb/cockroach/pkg/base" 21 "github.com/cockroachdb/cockroach/pkg/keys" 22 "github.com/cockroachdb/cockroach/pkg/kv" 23 "github.com/cockroachdb/cockroach/pkg/sql/distsql" 24 "github.com/cockroachdb/cockroach/pkg/sql/execinfra" 25 "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" 26 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 27 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 28 "github.com/cockroachdb/cockroach/pkg/sql/types" 29 "github.com/cockroachdb/cockroach/pkg/testutils/distsqlutils" 30 "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" 31 "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" 32 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 33 "github.com/cockroachdb/cockroach/pkg/util/randutil" 34 "github.com/cockroachdb/cockroach/pkg/util/uuid" 35 ) 36 37 var ( 38 // diffCtx is a decimal context used to perform subtractions between 39 // local and non-local decimal results to check if they are within 40 // 1ulp. Decimals within 1ulp is acceptable for high-precision 41 // decimal calculations. 42 diffCtx = tree.DecimalCtx.WithPrecision(0) 43 // Use to check for 1ulp. 44 bigOne = big.NewInt(1) 45 // floatPrecFmt is the format string with a precision of 3 (after 46 // decimal point) specified for float comparisons. Float aggregation 47 // operations involve unavoidable off-by-last-few-digits errors, which 48 // is expected. 49 floatPrecFmt = "%.3f" 50 ) 51 52 // runTestFlow runs a flow with the given processors and returns the results. 53 // Any errors stop the current test. 54 func runTestFlow( 55 t *testing.T, 56 srv serverutils.TestServerInterface, 57 txn *kv.Txn, 58 procs ...execinfrapb.ProcessorSpec, 59 ) sqlbase.EncDatumRows { 60 distSQLSrv := srv.DistSQLServer().(*distsql.ServerImpl) 61 62 leafInputState := txn.GetLeafTxnInputState(context.Background()) 63 req := execinfrapb.SetupFlowRequest{ 64 Version: execinfra.Version, 65 LeafTxnInputState: &leafInputState, 66 Flow: execinfrapb.FlowSpec{ 67 FlowID: execinfrapb.FlowID{UUID: uuid.MakeV4()}, 68 Processors: procs, 69 }, 70 } 71 72 var rowBuf distsqlutils.RowBuffer 73 74 ctx, flow, err := distSQLSrv.SetupSyncFlow(context.Background(), distSQLSrv.ParentMemoryMonitor, &req, &rowBuf) 75 if err != nil { 76 t.Fatal(err) 77 } 78 if err := flow.Start(ctx, func() {}); err != nil { 79 t.Fatal(err) 80 } 81 flow.Wait() 82 flow.Cleanup(ctx) 83 84 if !rowBuf.ProducerClosed() { 85 t.Errorf("output not closed") 86 } 87 88 var res sqlbase.EncDatumRows 89 for { 90 row, meta := rowBuf.Next() 91 if meta != nil { 92 if meta.LeafTxnFinalState != nil || meta.Metrics != nil { 93 continue 94 } 95 t.Fatalf("unexpected metadata: %v", meta) 96 } 97 if row == nil { 98 break 99 } 100 res = append(res, row) 101 } 102 103 return res 104 } 105 106 // checkDistAggregationInfo tests that a flow with multiple local stages and a 107 // final stage (in accordance with per DistAggregationInfo) gets the same result 108 // with a naive aggregation flow that has a single non-distributed stage. 109 // 110 // Both types of flows are set up and ran against the first numRows of the given 111 // table. We assume the table's first column is the primary key, with values 112 // from 1 to numRows. A non-PK column that works with the function is chosen. 113 func checkDistAggregationInfo( 114 ctx context.Context, 115 t *testing.T, 116 srv serverutils.TestServerInterface, 117 tableDesc *sqlbase.TableDescriptor, 118 colIdx int, 119 numRows int, 120 fn execinfrapb.AggregatorSpec_Func, 121 info DistAggregationInfo, 122 ) { 123 colType := tableDesc.Columns[colIdx].Type 124 125 makeTableReader := func(startPK, endPK int, streamID int) execinfrapb.ProcessorSpec { 126 tr := execinfrapb.TableReaderSpec{ 127 Table: *tableDesc, 128 Spans: make([]execinfrapb.TableReaderSpan, 1), 129 } 130 131 var err error 132 tr.Spans[0].Span.Key, err = sqlbase.TestingMakePrimaryIndexKey(tableDesc, startPK) 133 if err != nil { 134 t.Fatal(err) 135 } 136 tr.Spans[0].Span.EndKey, err = sqlbase.TestingMakePrimaryIndexKey(tableDesc, endPK) 137 if err != nil { 138 t.Fatal(err) 139 } 140 141 return execinfrapb.ProcessorSpec{ 142 Core: execinfrapb.ProcessorCoreUnion{TableReader: &tr}, 143 Post: execinfrapb.PostProcessSpec{ 144 Projection: true, 145 OutputColumns: []uint32{uint32(colIdx)}, 146 }, 147 Output: []execinfrapb.OutputRouterSpec{{ 148 Type: execinfrapb.OutputRouterSpec_PASS_THROUGH, 149 Streams: []execinfrapb.StreamEndpointSpec{ 150 {Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(streamID)}, 151 }, 152 }}, 153 } 154 } 155 156 txn := kv.NewTxn(ctx, srv.DB(), srv.NodeID()) 157 158 // First run a flow that aggregates all the rows without any local stages. 159 160 rowsNonDist := runTestFlow( 161 t, srv, txn, 162 makeTableReader(1, numRows+1, 0), 163 execinfrapb.ProcessorSpec{ 164 Input: []execinfrapb.InputSyncSpec{{ 165 Type: execinfrapb.InputSyncSpec_UNORDERED, 166 ColumnTypes: []*types.T{colType}, 167 Streams: []execinfrapb.StreamEndpointSpec{ 168 {Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: 0}, 169 }, 170 }}, 171 Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{ 172 Aggregations: []execinfrapb.AggregatorSpec_Aggregation{{Func: fn, ColIdx: []uint32{0}}}, 173 }}, 174 Output: []execinfrapb.OutputRouterSpec{{ 175 Type: execinfrapb.OutputRouterSpec_PASS_THROUGH, 176 Streams: []execinfrapb.StreamEndpointSpec{ 177 {Type: execinfrapb.StreamEndpointSpec_SYNC_RESPONSE}, 178 }, 179 }}, 180 }, 181 ) 182 183 numIntermediary := len(info.LocalStage) 184 numFinal := len(info.FinalStage) 185 for _, finalInfo := range info.FinalStage { 186 if len(finalInfo.LocalIdxs) == 0 { 187 t.Fatalf("final stage must specify input local indices: %#v", info) 188 } 189 for _, localIdx := range finalInfo.LocalIdxs { 190 if localIdx >= uint32(numIntermediary) { 191 t.Fatalf("local index %d out of bounds of local stages: %#v", localIdx, info) 192 } 193 } 194 } 195 196 // Now run a flow with 4 separate table readers, each with its own local 197 // stage, all feeding into a single final stage. 198 199 numParallel := 4 200 201 // The type(s) outputted by the local stage can be different than the input type 202 // (e.g. DECIMAL instead of INT). 203 intermediaryTypes := make([]*types.T, numIntermediary) 204 for i, fn := range info.LocalStage { 205 var err error 206 _, returnTyp, err := execinfrapb.GetAggregateInfo(fn, colType) 207 if err != nil { 208 t.Fatal(err) 209 } 210 intermediaryTypes[i] = returnTyp 211 } 212 213 localAggregations := make([]execinfrapb.AggregatorSpec_Aggregation, numIntermediary) 214 for i, fn := range info.LocalStage { 215 // Local aggregations have the same input. 216 localAggregations[i] = execinfrapb.AggregatorSpec_Aggregation{Func: fn, ColIdx: []uint32{0}} 217 } 218 finalAggregations := make([]execinfrapb.AggregatorSpec_Aggregation, numFinal) 219 for i, finalInfo := range info.FinalStage { 220 // Each local aggregation feeds into a final aggregation. 221 finalAggregations[i] = execinfrapb.AggregatorSpec_Aggregation{ 222 Func: finalInfo.Fn, 223 ColIdx: finalInfo.LocalIdxs, 224 } 225 } 226 227 if numParallel < numRows { 228 numParallel = numRows 229 } 230 finalProc := execinfrapb.ProcessorSpec{ 231 Input: []execinfrapb.InputSyncSpec{{ 232 Type: execinfrapb.InputSyncSpec_UNORDERED, 233 ColumnTypes: intermediaryTypes, 234 }}, 235 Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{ 236 Aggregations: finalAggregations, 237 }}, 238 Output: []execinfrapb.OutputRouterSpec{{ 239 Type: execinfrapb.OutputRouterSpec_PASS_THROUGH, 240 Streams: []execinfrapb.StreamEndpointSpec{ 241 {Type: execinfrapb.StreamEndpointSpec_SYNC_RESPONSE}, 242 }, 243 }}, 244 } 245 246 // The type(s) outputted by the final stage can be different than the 247 // input type (e.g. DECIMAL instead of INT). 248 finalOutputTypes := make([]*types.T, numFinal) 249 // Passed into FinalIndexing as the indices for the IndexedVars inputs 250 // to the post processor. 251 varIdxs := make([]int, numFinal) 252 for i, finalInfo := range info.FinalStage { 253 inputTypes := make([]*types.T, len(finalInfo.LocalIdxs)) 254 for i, localIdx := range finalInfo.LocalIdxs { 255 inputTypes[i] = intermediaryTypes[localIdx] 256 } 257 var err error 258 _, finalOutputTypes[i], err = execinfrapb.GetAggregateInfo(finalInfo.Fn, inputTypes...) 259 if err != nil { 260 t.Fatal(err) 261 } 262 varIdxs[i] = i 263 } 264 265 var procs []execinfrapb.ProcessorSpec 266 for i := 0; i < numParallel; i++ { 267 tr := makeTableReader(1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel, 2*i) 268 agg := execinfrapb.ProcessorSpec{ 269 Input: []execinfrapb.InputSyncSpec{{ 270 Type: execinfrapb.InputSyncSpec_UNORDERED, 271 ColumnTypes: []*types.T{colType}, 272 Streams: []execinfrapb.StreamEndpointSpec{ 273 {Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(2 * i)}, 274 }, 275 }}, 276 Core: execinfrapb.ProcessorCoreUnion{Aggregator: &execinfrapb.AggregatorSpec{ 277 Aggregations: localAggregations, 278 }}, 279 Output: []execinfrapb.OutputRouterSpec{{ 280 Type: execinfrapb.OutputRouterSpec_PASS_THROUGH, 281 Streams: []execinfrapb.StreamEndpointSpec{ 282 {Type: execinfrapb.StreamEndpointSpec_LOCAL, StreamID: execinfrapb.StreamID(2*i + 1)}, 283 }, 284 }}, 285 } 286 procs = append(procs, tr, agg) 287 finalProc.Input[0].Streams = append(finalProc.Input[0].Streams, execinfrapb.StreamEndpointSpec{ 288 Type: execinfrapb.StreamEndpointSpec_LOCAL, 289 StreamID: execinfrapb.StreamID(2*i + 1), 290 }) 291 } 292 293 if info.FinalRendering != nil { 294 h := tree.MakeTypesOnlyIndexedVarHelper(finalOutputTypes) 295 renderExpr, err := info.FinalRendering(&h, varIdxs) 296 if err != nil { 297 t.Fatal(err) 298 } 299 var expr execinfrapb.Expression 300 expr, err = MakeExpression(renderExpr, nil, nil) 301 if err != nil { 302 t.Fatal(err) 303 } 304 finalProc.Post.RenderExprs = []execinfrapb.Expression{expr} 305 306 } 307 308 procs = append(procs, finalProc) 309 rowsDist := runTestFlow(t, srv, txn, procs...) 310 311 if len(rowsDist[0]) != len(rowsNonDist[0]) { 312 t.Errorf("different row lengths (dist: %d non-dist: %d)", len(rowsDist[0]), len(rowsNonDist[0])) 313 } else { 314 for i := range rowsDist[0] { 315 rowDist := rowsDist[0][i] 316 rowNonDist := rowsNonDist[0][i] 317 if rowDist.Datum.ResolvedType().Family() != rowNonDist.Datum.ResolvedType().Family() { 318 t.Fatalf("different type for column %d (dist: %s non-dist: %s)", i, rowDist.Datum.ResolvedType(), rowNonDist.Datum.ResolvedType()) 319 } 320 321 var equiv bool 322 var strDist, strNonDist string 323 switch typedDist := rowDist.Datum.(type) { 324 case *tree.DDecimal: 325 // For some decimal operations, non-local and 326 // local computations may differ by the last 327 // digit (by 1 ulp). 328 decDist := &typedDist.Decimal 329 decNonDist := &rowNonDist.Datum.(*tree.DDecimal).Decimal 330 strDist = decDist.String() 331 strNonDist = decNonDist.String() 332 // We first check if they're equivalent, and if 333 // not, we check if they're within 1ulp. 334 equiv = decDist.Cmp(decNonDist) == 0 335 if !equiv { 336 if _, err := diffCtx.Sub(decNonDist, decNonDist, decDist); err != nil { 337 t.Fatal(err) 338 } 339 equiv = decNonDist.Coeff.Cmp(bigOne) == 0 340 } 341 case *tree.DFloat: 342 // Float results are highly variable and loss 343 // of precision between non-local and local is 344 // expected. We reduce the precision specified 345 // by floatPrecFmt and compare their string 346 // representations. 347 floatDist := float64(*typedDist) 348 floatNonDist := float64(*rowNonDist.Datum.(*tree.DFloat)) 349 strDist = fmt.Sprintf(floatPrecFmt, floatDist) 350 strNonDist = fmt.Sprintf(floatPrecFmt, floatNonDist) 351 // Compare using a relative equality 352 // func that isn't dependent on the scale 353 // of the number. In addition, ignore any 354 // NaNs. Sometimes due to the non-deterministic 355 // ordering of distsql, we get a +Inf and 356 // Nan result. Both of these changes started 357 // happening with the float rand datums 358 // were taught about some more adversarial 359 // inputs. Since floats by nature have equality 360 // problems and I think our algorithms are 361 // correct, we need to be slightly more lenient 362 // in our float comparisons. 363 equiv = almostEqualRelative(floatDist, floatNonDist) || math.IsNaN(floatNonDist) || math.IsNaN(floatDist) 364 default: 365 // For all other types, a simple string 366 // representation comparison will suffice. 367 strDist = rowDist.Datum.String() 368 strNonDist = rowNonDist.Datum.String() 369 equiv = strDist == strNonDist 370 } 371 if !equiv { 372 t.Errorf("different results for column %d\nw/o local stage: %s\nwith local stage: %s", i, strDist, strNonDist) 373 } 374 } 375 } 376 } 377 378 // almostEqualRelative returns whether a and b are close-enough to equal. It 379 // checks if the two numbers are within a certain relative percentage of 380 // each other (maxRelDiff), which avoids problems when using "%.3f" as a 381 // comparison string. This is the "Relative epsilon comparisons" method from: 382 // https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ 383 func almostEqualRelative(a, b float64) bool { 384 if a == b { 385 return true 386 } 387 // Calculate the difference. 388 diff := math.Abs(a - b) 389 A := math.Abs(a) 390 B := math.Abs(b) 391 // Find the largest 392 largest := A 393 if B > A { 394 largest = B 395 } 396 const maxRelDiff = 1e-10 397 return diff <= largest*maxRelDiff 398 } 399 400 // Test that distributing agg functions according to DistAggregationTable 401 // yields correct results. We're going to run each aggregation as either the 402 // two-stage process described by the DistAggregationTable or as a single global 403 // process, and verify that the results are the same. 404 func TestDistAggregationTable(t *testing.T) { 405 defer leaktest.AfterTest(t)() 406 const numRows = 100 407 408 tc := serverutils.StartTestCluster(t, 1, base.TestClusterArgs{}) 409 defer tc.Stopper().Stop(context.Background()) 410 411 // Create a table with a few columns: 412 // - random integer values from 0 to numRows 413 // - random integer values (with some NULLs) 414 // - random bool value (mostly false) 415 // - random bool value (mostly true) 416 // - random decimals 417 // - random decimals (with some NULLs) 418 rng, _ := randutil.NewPseudoRand() 419 sqlutils.CreateTable( 420 t, tc.ServerConn(0), "t", 421 "k INT PRIMARY KEY, int1 INT, int2 INT, bool1 BOOL, bool2 BOOL, dec1 DECIMAL, dec2 DECIMAL, float1 FLOAT, float2 FLOAT, b BYTES", 422 numRows, 423 func(row int) []tree.Datum { 424 return []tree.Datum{ 425 tree.NewDInt(tree.DInt(row)), 426 tree.NewDInt(tree.DInt(rng.Intn(numRows))), 427 sqlbase.RandDatum(rng, types.Int, true), 428 tree.MakeDBool(tree.DBool(rng.Intn(10) == 0)), 429 tree.MakeDBool(tree.DBool(rng.Intn(10) != 0)), 430 sqlbase.RandDatum(rng, types.Decimal, false), 431 sqlbase.RandDatum(rng, types.Decimal, true), 432 sqlbase.RandDatum(rng, types.Float, false), 433 sqlbase.RandDatum(rng, types.Float, true), 434 tree.NewDBytes(tree.DBytes(randutil.RandBytes(rng, 10))), 435 } 436 }, 437 ) 438 439 kvDB := tc.Server(0).DB() 440 desc := sqlbase.GetTableDescriptor(kvDB, keys.SystemSQLCodec, "test", "t") 441 442 for fn, info := range DistAggregationTable { 443 if fn == execinfrapb.AggregatorSpec_ANY_NOT_NULL { 444 // ANY_NOT_NULL only has a definite result if all rows have the same value 445 // on the relevant column; skip testing this trivial case. 446 continue 447 } 448 if fn == execinfrapb.AggregatorSpec_COUNT_ROWS { 449 // COUNT_ROWS takes no arguments; skip it in this test. 450 continue 451 } 452 // We're going to test each aggregation function on every column that can be 453 // used as input for it. 454 foundCol := false 455 for colIdx := 1; colIdx < len(desc.Columns); colIdx++ { 456 // See if this column works with this function. 457 _, _, err := execinfrapb.GetAggregateInfo(fn, desc.Columns[colIdx].Type) 458 if err != nil { 459 continue 460 } 461 foundCol = true 462 for _, numRows := range []int{5, numRows / 10, numRows / 2, numRows} { 463 name := fmt.Sprintf("%s/%s/%d", fn, desc.Columns[colIdx].Name, numRows) 464 t.Run(name, func(t *testing.T) { 465 checkDistAggregationInfo( 466 context.Background(), t, tc.Server(0), desc, colIdx, numRows, fn, info) 467 }) 468 } 469 } 470 if !foundCol { 471 t.Errorf("aggregation function %s was not tested (no suitable column)", fn) 472 } 473 } 474 }