github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/norm/decorrelate_funcs.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package norm
    12  
    13  import (
    14  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    15  	"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
    16  	"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/opt/props"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    19  	"github.com/cockroachdb/cockroach/pkg/util/log"
    20  	"github.com/cockroachdb/errors"
    21  )
    22  
    23  // HasHoistableSubquery returns true if the given scalar expression contains a
    24  // subquery within its subtree that has at least one outer column, and if that
    25  // subquery needs to be hoisted up into its parent query as part of query
    26  // decorrelation.
    27  func (c *CustomFuncs) HasHoistableSubquery(scalar opt.ScalarExpr) bool {
    28  	// Shortcut if the scalar has properties associated with it.
    29  	if scalarPropsExpr, ok := scalar.(memo.ScalarPropsExpr); ok {
    30  		// Don't bother traversing the expression tree if there is no subquery.
    31  		scalarProps := scalarPropsExpr.ScalarProps()
    32  		if !scalarProps.HasSubquery {
    33  			return false
    34  		}
    35  
    36  		// Lazily calculate and store the HasHoistableSubquery value.
    37  		if !scalarProps.IsAvailable(props.HasHoistableSubquery) {
    38  			scalarProps.Rule.HasHoistableSubquery = c.deriveHasHoistableSubquery(scalar)
    39  			scalarProps.SetAvailable(props.HasHoistableSubquery)
    40  		}
    41  		return scalarProps.Rule.HasHoistableSubquery
    42  	}
    43  
    44  	// Otherwise fall back on full traversal of subtree.
    45  	return c.deriveHasHoistableSubquery(scalar)
    46  }
    47  
    48  func (c *CustomFuncs) deriveHasHoistableSubquery(scalar opt.ScalarExpr) bool {
    49  	switch t := scalar.(type) {
    50  	case *memo.SubqueryExpr:
    51  		return !t.Input.Relational().OuterCols.Empty()
    52  
    53  	case *memo.ExistsExpr:
    54  		return !t.Input.Relational().OuterCols.Empty()
    55  
    56  	case *memo.ArrayFlattenExpr:
    57  		return !t.Input.Relational().OuterCols.Empty()
    58  
    59  	case *memo.AnyExpr:
    60  		// Don't hoist Any when only its Scalar operand is correlated, because it
    61  		// executes much slower. It's better to cache the results of the constant
    62  		// subquery in this case. Note that if an Any is at the top-level of a
    63  		// WHERE clause, it will be transformed to an Exists operator, so this case
    64  		// only occurs when the Any is nested, in a projection, etc.
    65  		return !t.Input.Relational().OuterCols.Empty()
    66  	}
    67  
    68  	// If HasHoistableSubquery is true for any child, then it's true for this
    69  	// expression as well. The exception is Case/If branches that have side
    70  	// effects. These can only be executed if the branch test evaluates to true,
    71  	// and so it's not possible to hoist out subqueries, since they would then be
    72  	// evaluated when they shouldn't be.
    73  	for i, n := 0, scalar.ChildCount(); i < n; i++ {
    74  		child := scalar.Child(i).(opt.ScalarExpr)
    75  		if c.deriveHasHoistableSubquery(child) {
    76  			var sharedProps props.Shared
    77  			hasHoistableSubquery := true
    78  
    79  			// Consider CASE WHEN and ELSE branches:
    80  			//   (Case
    81  			//     $input:*
    82  			//     (When $cond1:* $branch1:*)  # optional
    83  			//     (When $cond2:* $branch2:*)  # optional
    84  			//     $else:*                     # optional
    85  			//   )
    86  			switch t := scalar.(type) {
    87  			case *memo.CaseExpr:
    88  				// Determine whether this is the Else child.
    89  				if child == t.OrElse {
    90  					memo.BuildSharedProps(child, &sharedProps)
    91  					hasHoistableSubquery = !sharedProps.CanHaveSideEffects
    92  				}
    93  
    94  			case *memo.WhenExpr:
    95  				if child == t.Value {
    96  					memo.BuildSharedProps(child, &sharedProps)
    97  					hasHoistableSubquery = !sharedProps.CanHaveSideEffects
    98  				}
    99  
   100  			case *memo.IfErrExpr:
   101  				// Determine whether this is the Else child. Checking this how the
   102  				// other branches do is tricky because it's a list, but we know that
   103  				// it's at position 1.
   104  				if i == 1 {
   105  					memo.BuildSharedProps(child, &sharedProps)
   106  					hasHoistableSubquery = !sharedProps.CanHaveSideEffects
   107  				}
   108  			}
   109  
   110  			if hasHoistableSubquery {
   111  				return true
   112  			}
   113  		}
   114  	}
   115  	return false
   116  }
   117  
   118  // HoistSelectSubquery searches the Select operator's filter for correlated
   119  // subqueries. Any found queries are hoisted into LeftJoinApply or
   120  // InnerJoinApply operators, depending on subquery cardinality:
   121  //
   122  //   SELECT * FROM xy WHERE (SELECT u FROM uv WHERE u=x LIMIT 1) IS NULL
   123  //   =>
   124  //   SELECT xy.*
   125  //   FROM xy
   126  //   LEFT JOIN LATERAL (SELECT u FROM uv WHERE u=x LIMIT 1)
   127  //   ON True
   128  //   WHERE u IS NULL
   129  //
   130  func (c *CustomFuncs) HoistSelectSubquery(
   131  	input memo.RelExpr, filters memo.FiltersExpr,
   132  ) memo.RelExpr {
   133  	newFilters := make(memo.FiltersExpr, 0, len(filters))
   134  
   135  	var hoister subqueryHoister
   136  	hoister.init(c, input)
   137  	for i := range filters {
   138  		item := &filters[i]
   139  		if item.ScalarProps().Rule.HasHoistableSubquery {
   140  			replaced := hoister.hoistAll(item.Condition)
   141  			if replaced.Op() != opt.TrueOp {
   142  				newFilters = append(newFilters, c.f.ConstructFiltersItem(replaced))
   143  			}
   144  		} else {
   145  			newFilters = append(newFilters, *item)
   146  		}
   147  	}
   148  
   149  	sel := c.f.ConstructSelect(hoister.input(), newFilters)
   150  	return c.f.ConstructProject(sel, memo.EmptyProjectionsExpr, c.OutputCols(input))
   151  }
   152  
   153  // HoistProjectSubquery searches the Project operator's projections for
   154  // correlated subqueries. Any found queries are hoisted into LeftJoinApply
   155  // or InnerJoinApply operators, depending on subquery cardinality:
   156  //
   157  //   SELECT (SELECT max(u) FROM uv WHERE u=x) AS max FROM xy
   158  //   =>
   159  //   SELECT max
   160  //   FROM xy
   161  //   INNER JOIN LATERAL (SELECT max(u) FROM uv WHERE u=x)
   162  //   ON True
   163  //
   164  func (c *CustomFuncs) HoistProjectSubquery(
   165  	input memo.RelExpr, projections memo.ProjectionsExpr, passthrough opt.ColSet,
   166  ) memo.RelExpr {
   167  	newProjections := make(memo.ProjectionsExpr, 0, len(projections))
   168  
   169  	var hoister subqueryHoister
   170  	hoister.init(c, input)
   171  	for i := range projections {
   172  		item := &projections[i]
   173  		if item.ScalarProps().Rule.HasHoistableSubquery {
   174  			replaced := hoister.hoistAll(item.Element)
   175  			newProjections = append(newProjections, c.f.ConstructProjectionsItem(replaced, item.Col))
   176  		} else {
   177  			newProjections = append(newProjections, *item)
   178  		}
   179  	}
   180  
   181  	return c.f.ConstructProject(hoister.input(), newProjections, passthrough)
   182  }
   183  
   184  // HoistJoinSubquery searches the Join operator's filter for correlated
   185  // subqueries. Any found queries are hoisted into LeftJoinApply or
   186  // InnerJoinApply operators, depending on subquery cardinality:
   187  //
   188  //   SELECT y, z
   189  //   FROM xy
   190  //   FULL JOIN yz
   191  //   ON (SELECT u FROM uv WHERE u=x LIMIT 1) IS NULL
   192  //   =>
   193  //   SELECT y, z
   194  //   FROM xy
   195  //   FULL JOIN LATERAL
   196  //   (
   197  //     SELECT *
   198  //     FROM yz
   199  //     LEFT JOIN LATERAL (SELECT u FROM uv WHERE u=x LIMIT 1)
   200  //     ON True
   201  //   )
   202  //   ON u IS NULL
   203  //
   204  func (c *CustomFuncs) HoistJoinSubquery(
   205  	op opt.Operator, left, right memo.RelExpr, on memo.FiltersExpr, private *memo.JoinPrivate,
   206  ) memo.RelExpr {
   207  	newFilters := make(memo.FiltersExpr, 0, len(on))
   208  
   209  	var hoister subqueryHoister
   210  	hoister.init(c, right)
   211  	for i := range on {
   212  		item := &on[i]
   213  		if item.ScalarProps().Rule.HasHoistableSubquery {
   214  			replaced := hoister.hoistAll(item.Condition)
   215  			if replaced.Op() != opt.TrueOp {
   216  				newFilters = append(newFilters, c.f.ConstructFiltersItem(replaced))
   217  			}
   218  		} else {
   219  			newFilters = append(newFilters, *item)
   220  		}
   221  	}
   222  
   223  	join := c.ConstructApplyJoin(op, left, hoister.input(), newFilters, private)
   224  	passthrough := c.OutputCols(left).Union(c.OutputCols(right))
   225  	return c.f.ConstructProject(join, memo.EmptyProjectionsExpr, passthrough)
   226  }
   227  
   228  // HoistValuesSubquery searches the Values operator's projections for correlated
   229  // subqueries. Any found queries are hoisted into LeftJoinApply or
   230  // InnerJoinApply operators, depending on subquery cardinality:
   231  //
   232  //   SELECT (VALUES (SELECT u FROM uv WHERE u=x LIMIT 1)) FROM xy
   233  //   =>
   234  //   SELECT
   235  //   (
   236  //     SELECT vals.*
   237  //     FROM (VALUES ())
   238  //     LEFT JOIN LATERAL (SELECT u FROM uv WHERE u=x LIMIT 1)
   239  //     ON True
   240  //     INNER JOIN LATERAL (VALUES (u)) vals
   241  //     ON True
   242  //   )
   243  //   FROM xy
   244  //
   245  // The dummy VALUES clause with a singleton empty row is added to the tree in
   246  // order to use the hoister, which requires an initial input query. While a
   247  // right join would be slightly better here, this is such a fringe case that
   248  // it's not worth the extra code complication.
   249  func (c *CustomFuncs) HoistValuesSubquery(
   250  	rows memo.ScalarListExpr, private *memo.ValuesPrivate,
   251  ) memo.RelExpr {
   252  	newRows := make(memo.ScalarListExpr, 0, len(rows))
   253  
   254  	var hoister subqueryHoister
   255  	hoister.init(c, c.ConstructNoColsRow())
   256  	for _, item := range rows {
   257  		newRows = append(newRows, hoister.hoistAll(item))
   258  	}
   259  
   260  	values := c.f.ConstructValues(newRows, &memo.ValuesPrivate{
   261  		Cols: private.Cols,
   262  		ID:   c.f.Metadata().NextUniqueID(),
   263  	})
   264  	join := c.f.ConstructInnerJoinApply(hoister.input(), values, memo.TrueFilter, memo.EmptyJoinPrivate)
   265  	outCols := values.Relational().OutputCols
   266  	return c.f.ConstructProject(join, memo.EmptyProjectionsExpr, outCols)
   267  }
   268  
   269  // HoistProjectSetSubquery searches the ProjectSet operator's functions for
   270  // correlated subqueries. Any found queries are hoisted into LeftJoinApply or
   271  // InnerJoinApply operators, depending on subquery cardinality:
   272  //
   273  //   SELECT generate_series
   274  //   FROM xy
   275  //   INNER JOIN LATERAL ROWS FROM
   276  //   (
   277  //     generate_series(1, (SELECT v FROM uv WHERE u=x))
   278  //   )
   279  //   =>
   280  //   SELECT generate_series
   281  //   FROM xy
   282  //   ROWS FROM
   283  //   (
   284  //     SELECT generate_series
   285  //     FROM (VALUES ())
   286  //     LEFT JOIN LATERAL (SELECT v FROM uv WHERE u=x)
   287  //     ON True
   288  //     INNER JOIN LATERAL ROWS FROM (generate_series(1, v))
   289  //     ON True
   290  //   )
   291  //
   292  func (c *CustomFuncs) HoistProjectSetSubquery(input memo.RelExpr, zip memo.ZipExpr) memo.RelExpr {
   293  	newZip := make(memo.ZipExpr, 0, len(zip))
   294  
   295  	var hoister subqueryHoister
   296  	hoister.init(c, input)
   297  	for i := range zip {
   298  		item := &zip[i]
   299  		if item.ScalarProps().Rule.HasHoistableSubquery {
   300  			replaced := hoister.hoistAll(item.Fn)
   301  			newZip = append(newZip, c.f.ConstructZipItem(replaced, item.Cols))
   302  		} else {
   303  			newZip = append(newZip, *item)
   304  		}
   305  	}
   306  
   307  	// The process of hoisting will introduce additional columns, so we introduce
   308  	// a projection to not include those in the output.
   309  	outputCols := c.OutputCols(input).Union(zip.OutputCols())
   310  
   311  	projectSet := c.f.ConstructProjectSet(hoister.input(), newZip)
   312  	return c.f.ConstructProject(projectSet, memo.EmptyProjectionsExpr, outputCols)
   313  }
   314  
   315  // ConstructNonApplyJoin constructs the non-apply join operator that corresponds
   316  // to the given join operator type.
   317  func (c *CustomFuncs) ConstructNonApplyJoin(
   318  	joinOp opt.Operator, left, right memo.RelExpr, on memo.FiltersExpr, private *memo.JoinPrivate,
   319  ) memo.RelExpr {
   320  	switch joinOp {
   321  	case opt.InnerJoinOp, opt.InnerJoinApplyOp:
   322  		return c.f.ConstructInnerJoin(left, right, on, private)
   323  	case opt.LeftJoinOp, opt.LeftJoinApplyOp:
   324  		return c.f.ConstructLeftJoin(left, right, on, private)
   325  	case opt.SemiJoinOp, opt.SemiJoinApplyOp:
   326  		return c.f.ConstructSemiJoin(left, right, on, private)
   327  	case opt.AntiJoinOp, opt.AntiJoinApplyOp:
   328  		return c.f.ConstructAntiJoin(left, right, on, private)
   329  	}
   330  	panic(errors.AssertionFailedf("unexpected join operator: %v", log.Safe(joinOp)))
   331  }
   332  
   333  // ConstructApplyJoin constructs the apply join operator that corresponds
   334  // to the given join operator type.
   335  func (c *CustomFuncs) ConstructApplyJoin(
   336  	joinOp opt.Operator, left, right memo.RelExpr, on memo.FiltersExpr, private *memo.JoinPrivate,
   337  ) memo.RelExpr {
   338  	switch joinOp {
   339  	case opt.InnerJoinOp, opt.InnerJoinApplyOp:
   340  		return c.f.ConstructInnerJoinApply(left, right, on, private)
   341  	case opt.LeftJoinOp, opt.LeftJoinApplyOp:
   342  		return c.f.ConstructLeftJoinApply(left, right, on, private)
   343  	case opt.SemiJoinOp, opt.SemiJoinApplyOp:
   344  		return c.f.ConstructSemiJoinApply(left, right, on, private)
   345  	case opt.AntiJoinOp, opt.AntiJoinApplyOp:
   346  		return c.f.ConstructAntiJoinApply(left, right, on, private)
   347  	}
   348  	panic(errors.AssertionFailedf("unexpected join operator: %v", log.Safe(joinOp)))
   349  }
   350  
   351  // EnsureKey finds the shortest strong key for the input expression. If no
   352  // strong key exists and the input expression is a scan, EnsureKey returns a new
   353  // Scan with the preexisting primary key for the table. If the input is not a
   354  // Scan, EnsureKey wraps the input in an Ordinality operator, which provides a
   355  // key column by uniquely numbering the rows. EnsureKey returns the input
   356  // expression (perhaps augmented with a key column(s) or wrapped by Ordinality).
   357  func (c *CustomFuncs) EnsureKey(in memo.RelExpr) memo.RelExpr {
   358  	_, ok := c.CandidateKey(in)
   359  	if ok {
   360  		return in
   361  	}
   362  
   363  	switch t := in.(type) {
   364  	case *memo.ScanExpr:
   365  		// Add primary key columns if this is a non-virtual table.
   366  		private := t.ScanPrivate
   367  		tableID := private.Table
   368  		table := c.f.Metadata().Table(tableID)
   369  		if !table.IsVirtualTable() {
   370  			keyCols := c.f.Metadata().TableMeta(tableID).IndexKeyColumns(cat.PrimaryIndex)
   371  			private.Cols = private.Cols.Union(keyCols)
   372  			return c.f.ConstructScan(&private)
   373  		}
   374  	}
   375  
   376  	colID := c.f.Metadata().AddColumn("rownum", types.Int)
   377  	private := memo.OrdinalityPrivate{ColID: colID}
   378  	return c.f.ConstructOrdinality(in, &private)
   379  }
   380  
   381  // KeyCols returns a column set consisting of the columns that make up the
   382  // candidate key for the input expression (a key must be present).
   383  func (c *CustomFuncs) KeyCols(in memo.RelExpr) opt.ColSet {
   384  	keyCols, ok := c.CandidateKey(in)
   385  	if !ok {
   386  		panic(errors.AssertionFailedf("expected expression to have key"))
   387  	}
   388  	return keyCols
   389  }
   390  
   391  // NonKeyCols returns a column set consisting of the output columns of the given
   392  // input, minus the columns that make up its candidate key (which it must have).
   393  func (c *CustomFuncs) NonKeyCols(in memo.RelExpr) opt.ColSet {
   394  	keyCols, ok := c.CandidateKey(in)
   395  	if !ok {
   396  		panic(errors.AssertionFailedf("expected expression to have key"))
   397  	}
   398  	return c.OutputCols(in).Difference(keyCols)
   399  }
   400  
   401  // MakeAggCols constructs a new Aggregations operator containing an aggregate
   402  // function of the given operator type for each of column in the given set. For
   403  // example, for ConstAggOp and columns (1,2), this expression is returned:
   404  //
   405  //   (Aggregations
   406  //     [(ConstAgg (Variable 1)) (ConstAgg (Variable 2))]
   407  //     [1,2]
   408  //   )
   409  //
   410  func (c *CustomFuncs) MakeAggCols(aggOp opt.Operator, cols opt.ColSet) memo.AggregationsExpr {
   411  	colsLen := cols.Len()
   412  	aggs := make(memo.AggregationsExpr, colsLen)
   413  	c.makeAggCols(aggOp, cols, aggs)
   414  	return aggs
   415  }
   416  
   417  // MakeAggCols2 is similar to MakeAggCols, except that it allows two different
   418  // sets of aggregate functions to be added to the resulting Aggregations
   419  // operator, with one set appended to the other, like this:
   420  //
   421  //   (Aggregations
   422  //     [(ConstAgg (Variable 1)) (ConstAgg (Variable 2)) (FirstAgg (Variable 3))]
   423  //     [1,2,3]
   424  //   )
   425  //
   426  func (c *CustomFuncs) MakeAggCols2(
   427  	aggOp opt.Operator, cols opt.ColSet, aggOp2 opt.Operator, cols2 opt.ColSet,
   428  ) memo.AggregationsExpr {
   429  	colsLen := cols.Len()
   430  	aggs := make(memo.AggregationsExpr, colsLen+cols2.Len())
   431  	c.makeAggCols(aggOp, cols, aggs)
   432  	c.makeAggCols(aggOp2, cols2, aggs[colsLen:])
   433  	return aggs
   434  }
   435  
   436  // EnsureCanaryCol checks whether an aggregation which cannot ignore nulls exists.
   437  // If one does, it then checks if there are any non-null columns in the input.
   438  // If there is not one, it synthesizes a new True constant column that is
   439  // not-null. This becomes a kind of "canary" column that other expressions can
   440  // inspect, since any null value in this column indicates that the row was
   441  // added by an outer join as part of null extending.
   442  //
   443  // EnsureCanaryCol returns the input expression, possibly wrapped in a new
   444  // Project if a new column was synthesized.
   445  //
   446  // See the TryDecorrelateScalarGroupBy rule comment for more details.
   447  func (c *CustomFuncs) EnsureCanaryCol(in memo.RelExpr, aggs memo.AggregationsExpr) opt.ColumnID {
   448  	for i := range aggs {
   449  		if !opt.AggregateIgnoresNulls(aggs[i].Agg.Op()) {
   450  			// Look for an existing not null column that is not projected by a
   451  			// passthrough aggregate like ConstAgg.
   452  			id, ok := in.Relational().NotNullCols.Next(0)
   453  			if ok && !aggs.OutputCols().Contains(id) {
   454  				return id
   455  			}
   456  
   457  			// Synthesize a new column ID.
   458  			return c.f.Metadata().AddColumn("canary", types.Bool)
   459  		}
   460  	}
   461  	return 0
   462  }
   463  
   464  // EnsureCanary makes sure that if canaryCol is set, it is projected by the
   465  // input expression.
   466  //
   467  // See the TryDecorrelateScalarGroupBy rule comment for more details.
   468  func (c *CustomFuncs) EnsureCanary(in memo.RelExpr, canaryCol opt.ColumnID) memo.RelExpr {
   469  	if canaryCol == 0 || c.OutputCols(in).Contains(canaryCol) {
   470  		return in
   471  	}
   472  	result := c.ProjectExtraCol(in, c.f.ConstructTrue(), canaryCol)
   473  	return result
   474  }
   475  
   476  // CanaryColSet returns a singleton set containing the canary column if set,
   477  // otherwise the empty set.
   478  func (c *CustomFuncs) CanaryColSet(canaryCol opt.ColumnID) opt.ColSet {
   479  	var colSet opt.ColSet
   480  	if canaryCol != 0 {
   481  		colSet.Add(canaryCol)
   482  	}
   483  	return colSet
   484  }
   485  
   486  // AggsCanBeDecorrelated returns true if every aggregate satisfies one of the
   487  // following conditions:
   488  //
   489  //   * It is CountRows (because it will be translated into Count),
   490  //   * It ignores nulls (because nothing extra must be done for it)
   491  //   * It gives NULL on no input (because this is how we translate non-null
   492  //     ignoring aggregates)
   493  //
   494  // TODO(justin): we can lift the third condition if we have a function that
   495  // gives the correct "on empty" value for a given aggregate.
   496  func (c *CustomFuncs) AggsCanBeDecorrelated(aggs memo.AggregationsExpr) bool {
   497  	for i := range aggs {
   498  		agg := aggs[i].Agg
   499  		op := agg.Op()
   500  		if op == opt.AggFilterOp || op == opt.AggDistinctOp {
   501  			// TODO(radu): investigate if we can do better here
   502  			return false
   503  		}
   504  		if !(op == opt.CountRowsOp || opt.AggregateIgnoresNulls(op) || opt.AggregateIsNullOnEmpty(op)) {
   505  			return false
   506  		}
   507  	}
   508  
   509  	return true
   510  }
   511  
   512  // constructCanaryChecker returns a CASE expression which disambiguates an
   513  // aggregation over a left join having received a NULL column because there
   514  // were no matches on the right side of the join, and having received a NULL
   515  // column because a NULL column was matched against.
   516  func (c *CustomFuncs) constructCanaryChecker(
   517  	aggCanaryVar opt.ScalarExpr, inputCol opt.ColumnID,
   518  ) opt.ScalarExpr {
   519  	return c.f.ConstructCase(
   520  		memo.TrueSingleton,
   521  		memo.ScalarListExpr{
   522  			c.f.ConstructWhen(
   523  				c.f.ConstructIsNot(aggCanaryVar, memo.NullSingleton),
   524  				c.f.ConstructVariable(inputCol),
   525  			),
   526  		},
   527  		memo.NullSingleton,
   528  	)
   529  }
   530  
   531  // TranslateNonIgnoreAggs checks if any of the aggregates being decorrelated
   532  // are unable to ignore nulls. If that is the case, it inserts projections
   533  // which check a "canary" aggregation that determines if an expression actually
   534  // had any things grouped into it or not.
   535  func (c *CustomFuncs) TranslateNonIgnoreAggs(
   536  	newIn memo.RelExpr,
   537  	newAggs memo.AggregationsExpr,
   538  	oldIn memo.RelExpr,
   539  	oldAggs memo.AggregationsExpr,
   540  	canaryCol opt.ColumnID,
   541  ) memo.RelExpr {
   542  	var aggCanaryVar opt.ScalarExpr
   543  	passthrough := c.OutputCols(newIn).Copy()
   544  	passthrough.Remove(canaryCol)
   545  
   546  	var projections memo.ProjectionsExpr
   547  	for i := range newAggs {
   548  		agg := newAggs[i].Agg
   549  		if !opt.AggregateIgnoresNulls(agg.Op()) {
   550  			if aggCanaryVar == nil {
   551  				if canaryCol == 0 {
   552  					id, ok := oldIn.Relational().NotNullCols.Next(0)
   553  					if !ok {
   554  						panic(errors.AssertionFailedf("expected input expression to have not-null column"))
   555  					}
   556  					canaryCol = id
   557  				}
   558  				aggCanaryVar = c.f.ConstructVariable(canaryCol)
   559  			}
   560  
   561  			if !opt.AggregateIsNullOnEmpty(agg.Op()) {
   562  				// If this gets triggered we need to modify constructCanaryChecker to
   563  				// have a special "on-empty" value. This shouldn't get triggered
   564  				// because as of writing the only operation that is false for both
   565  				// AggregateIgnoresNulls and AggregateIsNullOnEmpty is CountRows, and
   566  				// we translate that into Count.
   567  				// TestAllAggsIgnoreNullsOrNullOnEmpty verifies that this assumption is
   568  				// true.
   569  				panic(errors.AssertionFailedf("can't decorrelate with aggregate %s", log.Safe(agg.Op())))
   570  			}
   571  
   572  			if projections == nil {
   573  				projections = make(memo.ProjectionsExpr, 0, len(newAggs)-i)
   574  			}
   575  			projections = append(projections, c.f.ConstructProjectionsItem(
   576  				c.constructCanaryChecker(aggCanaryVar, newAggs[i].Col),
   577  				oldAggs[i].Col,
   578  			))
   579  			passthrough.Remove(newAggs[i].Col)
   580  		}
   581  	}
   582  
   583  	if projections == nil {
   584  		return newIn
   585  	}
   586  	return c.f.ConstructProject(newIn, projections, passthrough)
   587  }
   588  
   589  // EnsureAggsCanIgnoreNulls scans the aggregate list to aggregation functions that
   590  // don't ignore nulls but can be remapped so that they do:
   591  //   - CountRows functions are are converted to Count functions that operate
   592  //     over a not-null column from the given input expression. The
   593  //     EnsureNotNullIfNeeded method should already have been called in order
   594  //     to guarantee such a column exists.
   595  //   - ConstAgg is remapped to ConstNotNullAgg.
   596  //   - Other aggregates that can use a canary column to detect nulls.
   597  //
   598  // See the TryDecorrelateScalarGroupBy rule comment for more details.
   599  func (c *CustomFuncs) EnsureAggsCanIgnoreNulls(
   600  	in memo.RelExpr, aggs memo.AggregationsExpr,
   601  ) memo.AggregationsExpr {
   602  	var newAggs memo.AggregationsExpr
   603  	for i := range aggs {
   604  		newAgg := aggs[i].Agg
   605  		newCol := aggs[i].Col
   606  
   607  		switch t := newAgg.(type) {
   608  		case *memo.ConstAggExpr:
   609  			// Translate ConstAgg(...) to ConstNotNullAgg(...).
   610  			newAgg = c.f.ConstructConstNotNullAgg(t.Input)
   611  
   612  		case *memo.CountRowsExpr:
   613  			// Translate CountRows() to Count(notNullCol).
   614  			id, ok := in.Relational().NotNullCols.Next(0)
   615  			if !ok {
   616  				panic(errors.AssertionFailedf("expected input expression to have not-null column"))
   617  			}
   618  			notNullColID := id
   619  			newAgg = c.f.ConstructCount(c.f.ConstructVariable(notNullColID))
   620  
   621  		default:
   622  			if !opt.AggregateIgnoresNulls(t.Op()) {
   623  				// Allocate id for new intermediate agg column. The column will get
   624  				// mapped back to the original id after the grouping (by the
   625  				// TranslateNonIgnoreAggs method).
   626  				md := c.f.Metadata()
   627  				colMeta := md.ColumnMeta(newCol)
   628  				newCol = md.AddColumn(colMeta.Alias, colMeta.Type)
   629  			}
   630  		}
   631  		if newAggs == nil {
   632  			if newAgg != aggs[i].Agg || newCol != aggs[i].Col {
   633  				newAggs = make(memo.AggregationsExpr, len(aggs))
   634  				copy(newAggs, aggs[:i])
   635  			}
   636  		}
   637  		if newAggs != nil {
   638  			newAggs[i] = c.f.ConstructAggregationsItem(newAgg, newCol)
   639  		}
   640  	}
   641  	if newAggs == nil {
   642  		// No changes.
   643  		return aggs
   644  	}
   645  	return newAggs
   646  }
   647  
   648  // AddColsToPartition unions the given set of columns with a window private's
   649  // partition columns.
   650  func (c *CustomFuncs) AddColsToPartition(
   651  	priv *memo.WindowPrivate, cols opt.ColSet,
   652  ) *memo.WindowPrivate {
   653  	cpy := *priv
   654  	cpy.Partition = cpy.Partition.Union(cols)
   655  	return &cpy
   656  }
   657  
   658  // ConstructAnyCondition builds an expression that compares the given scalar
   659  // expression with the first (and only) column of the input rowset, using the
   660  // given comparison operator.
   661  func (c *CustomFuncs) ConstructAnyCondition(
   662  	input memo.RelExpr, scalar opt.ScalarExpr, private *memo.SubqueryPrivate,
   663  ) opt.ScalarExpr {
   664  	inputVar := c.referenceSingleColumn(input)
   665  	return c.ConstructBinary(private.Cmp, scalar, inputVar)
   666  }
   667  
   668  // ConstructBinary builds a dynamic binary expression, given the binary
   669  // operator's type and its two arguments.
   670  func (c *CustomFuncs) ConstructBinary(op opt.Operator, left, right opt.ScalarExpr) opt.ScalarExpr {
   671  	return c.f.DynamicConstruct(op, left, right).(opt.ScalarExpr)
   672  }
   673  
   674  // ConstructNoColsRow returns a Values operator having a single row with zero
   675  // columns.
   676  func (c *CustomFuncs) ConstructNoColsRow() memo.RelExpr {
   677  	return c.f.ConstructValues(memo.ScalarListWithEmptyTuple, &memo.ValuesPrivate{
   678  		Cols: opt.ColList{},
   679  		ID:   c.f.Metadata().NextUniqueID(),
   680  	})
   681  }
   682  
   683  // referenceSingleColumn returns a Variable operator that refers to the one and
   684  // only column that is projected by the input expression.
   685  func (c *CustomFuncs) referenceSingleColumn(in memo.RelExpr) opt.ScalarExpr {
   686  	colID := in.Relational().OutputCols.SingleColumn()
   687  	return c.f.ConstructVariable(colID)
   688  }
   689  
// subqueryHoister searches scalar expression trees looking for correlated
// subqueries which will be pulled up and joined to a higher level relational
// query. See the hoistAll comment for more details on how this is done.
//
// A hoister is initialized with init. Each hoisted subquery wraps the
// accumulated input, which is then retrieved with input.
type subqueryHoister struct {
	// c is the CustomFuncs that owns this hoister (set by init).
	c *CustomFuncs
	// f is the factory used to construct new expressions (copied from c.f).
	f *Factory
	// mem is the memo (copied from c.mem).
	mem *memo.Memo
	// hoisted starts as the input passed to init and is replaced as each
	// subquery is hoisted. It is returned by the input method.
	hoisted memo.RelExpr
}
   699  
   700  func (r *subqueryHoister) init(c *CustomFuncs, input memo.RelExpr) {
   701  	r.c = c
   702  	r.f = c.f
   703  	r.mem = c.mem
   704  	r.hoisted = input
   705  }
   706  
   707  // input returns a single expression tree that contains the input expression
   708  // provided to the init method, but wrapped with any subqueries hoisted out of
   709  // the scalar expression tree. See the hoistAll comment for more details.
   710  func (r *subqueryHoister) input() memo.RelExpr {
   711  	return r.hoisted
   712  }
   713  
   714  // hoistAll searches the given subtree for each correlated Subquery, Exists, or
   715  // Any operator, and lifts its subquery operand out of the scalar context and
   716  // joins it with a higher-level relational expression. The original subquery
   717  // operand is replaced by a Variable operator that refers to the first (and
   718  // only) column of the hoisted relational expression.
   719  //
   720  // hoistAll returns the root of a new expression tree that incorporates the new
   721  // Variable operators. The hoisted subqueries can be accessed via the input
   722  // method. Each removed subquery wraps the one before, with the input query at
   723  // the base. Each subquery adds a single column to its input and uses a
   724  // JoinApply operator to ensure that it has no effect on the cardinality of its
   725  // input. For example:
   726  //
   727  //   SELECT *
   728  //   FROM xy
   729  //   WHERE
   730  //     (SELECT u FROM uv WHERE u=x LIMIT 1) IS NOT NULL
   731  //     OR EXISTS(SELECT * FROM jk WHERE j=x)
   732  //   =>
   733  //   SELECT xy.*
   734  //   FROM xy
   735  //   LEFT JOIN LATERAL (SELECT u FROM uv WHERE u=x LIMIT 1)
   736  //   ON True
   737  //   INNER JOIN LATERAL
   738  //   (
   739  //     SELECT (CONST_AGG(True) IS NOT NULL) AS exists FROM jk WHERE j=x
   740  //   )
   741  //   ON True
   742  //   WHERE u IS NOT NULL OR exists
   743  //
   744  // The choice of whether to use LeftJoinApply or InnerJoinApply depends on the
   745  // cardinality of the hoisted subquery. If zero rows can be returned from the
   746  // subquery, then LeftJoinApply must be used in order to preserve the
   747  // cardinality of the input expression. Otherwise, InnerJoinApply can be used
   748  // instead. In either case, the wrapped subquery must never return more than one
   749  // row, so as not to change the cardinality of the result.
   750  //
   751  // See the comments for constructGroupByExists and constructGroupByAny for more
   752  // details on how EXISTS and ANY subqueries are hoisted, including usage of the
   753  // CONST_AGG function.
   754  func (r *subqueryHoister) hoistAll(scalar opt.ScalarExpr) opt.ScalarExpr {
   755  	// Match correlated subqueries.
   756  	switch scalar.Op() {
   757  	case opt.SubqueryOp, opt.ExistsOp, opt.AnyOp, opt.ArrayFlattenOp:
   758  		subquery := scalar.Child(0).(memo.RelExpr)
   759  		if subquery.Relational().OuterCols.Empty() {
   760  			break
   761  		}
   762  
   763  		switch t := scalar.(type) {
   764  		case *memo.ExistsExpr:
   765  			subquery = r.constructGroupByExists(subquery)
   766  
   767  		case *memo.AnyExpr:
   768  			subquery = r.constructGroupByAny(t.Scalar, t.Cmp, t.Input)
   769  		}
   770  
   771  		// Hoist the subquery into a single expression that can be accessed via
   772  		// the subqueries method.
   773  		subqueryProps := subquery.Relational()
   774  		if subqueryProps.Cardinality.CanBeZero() {
   775  			// Zero cardinality allowed, so must use left outer join to preserve
   776  			// outer row (padded with nulls) in case the subquery returns zero rows.
   777  			r.hoisted = r.f.ConstructLeftJoinApply(r.hoisted, subquery, memo.TrueFilter, memo.EmptyJoinPrivate)
   778  		} else {
   779  			// Zero cardinality not allowed, so inner join suffices. Inner joins
   780  			// are preferable to left joins since null handling is much simpler
   781  			// and they allow the optimizer more choices.
   782  			r.hoisted = r.f.ConstructInnerJoinApply(r.hoisted, subquery, memo.TrueFilter, memo.EmptyJoinPrivate)
   783  		}
   784  
   785  		// Replace the Subquery operator with a Variable operator referring to
   786  		// the output column of the hoisted query.
   787  		var colID opt.ColumnID
   788  		switch t := scalar.(type) {
   789  		case *memo.ArrayFlattenExpr:
   790  			colID = t.RequestedCol
   791  		default:
   792  			colID = subqueryProps.OutputCols.SingleColumn()
   793  		}
   794  		return r.f.ConstructVariable(colID)
   795  	}
   796  
   797  	return r.f.Replace(scalar, func(nd opt.Expr) opt.Expr {
   798  		// Recursively hoist subqueries in each scalar child that contains them.
   799  		// Skip relational children, since only subquery scalar operators have a
   800  		// relational child, and either:
   801  		//
   802  		//   1. The child is correlated, and therefore was handled above by hoisting
   803  		//      and rewriting (and therefore won't ever get here),
   804  		//
   805  		//   2. Or the child is uncorrelated, and therefore should be skipped, since
   806  		//      uncorrelated subqueries are not hoisted.
   807  		//
   808  		if scalarChild, ok := nd.(opt.ScalarExpr); ok {
   809  			return r.hoistAll(scalarChild)
   810  		}
   811  		return nd
   812  	}).(opt.ScalarExpr)
   813  }
   814  
   815  // constructGroupByExists transforms a scalar Exists expression like this:
   816  //
   817  //   EXISTS(SELECT * FROM a WHERE a.x=b.x)
   818  //
   819  // into a scalar GroupBy expression that returns a one row, one column relation:
   820  //
   821  //   SELECT (CONST_AGG(True) IS NOT NULL) AS exists
   822  //   FROM (SELECT * FROM a WHERE a.x=b.x)
   823  //
   824  // The expression uses an internally-defined CONST_AGG aggregation function,
   825  // since it's able to short-circuit on the first non-null it encounters. The
   826  // above expression is equivalent to:
   827  //
   828  //   SELECT COUNT(True) > 0 FROM (SELECT * FROM a WHERE a.x=b.x)
   829  //
   830  // CONST_AGG (and COUNT) always return exactly one boolean value in the context
   831  // of a scalar GroupBy expression. Because its operand is always True, the only
   832  // way the final expression is False is when the input set is empty (since
   833  // CONST_AGG returns NULL, which IS NOT NULL maps to False).
   834  //
   835  // However, later on, the TryDecorrelateScalarGroupBy rule will push a left join
   836  // into the GroupBy, and null values produced by the join will flow into the
   837  // CONST_AGG which will need to be changed to a CONST_NOT_NULL_AGG (which is
   838  // defined to ignore those nulls so that its result will be unaffected).
   839  func (r *subqueryHoister) constructGroupByExists(subquery memo.RelExpr) memo.RelExpr {
   840  	trueColID := r.f.Metadata().AddColumn("true", types.Bool)
   841  	aggColID := r.f.Metadata().AddColumn("true_agg", types.Bool)
   842  	existsColID := r.f.Metadata().AddColumn("exists", types.Bool)
   843  
   844  	return r.f.ConstructProject(
   845  		r.f.ConstructScalarGroupBy(
   846  			r.f.ConstructProject(
   847  				subquery,
   848  				memo.ProjectionsExpr{r.f.ConstructProjectionsItem(memo.TrueSingleton, trueColID)},
   849  				opt.ColSet{},
   850  			),
   851  			memo.AggregationsExpr{r.f.ConstructAggregationsItem(
   852  				r.f.ConstructConstAgg(r.f.ConstructVariable(trueColID)),
   853  				aggColID,
   854  			)},
   855  			memo.EmptyGroupingPrivate,
   856  		),
   857  		memo.ProjectionsExpr{r.f.ConstructProjectionsItem(
   858  			r.f.ConstructIsNot(
   859  				r.f.ConstructVariable(aggColID),
   860  				memo.NullSingleton,
   861  			),
   862  			existsColID,
   863  		)},
   864  		opt.ColSet{},
   865  	)
   866  }
   867  
   868  // constructGroupByAny transforms a scalar Any expression like this:
   869  //
   870  //   z = ANY(SELECT x FROM xy)
   871  //
   872  // into a scalar GroupBy expression that returns a one row, one column relation
   873  // that is equivalent to this:
   874  //
   875  //   SELECT
   876  //     CASE
   877  //       WHEN bool_or(notnull) AND z IS NOT Null THEN True
   878  //       ELSE bool_or(notnull) IS NULL THEN False
   879  //       ELSE Null
   880  //     END
   881  //   FROM
   882  //   (
   883  //     SELECT x IS NOT Null AS notnull
   884  //     FROM xy
   885  //     WHERE (z=x) IS NOT False
   886  //   )
   887  //
   888  // BOOL_OR returns true if any input is true, else false if any input is false,
   889  // else null. This is a mismatch with ANY, which returns true if any input is
   890  // true, else null if any input is null, else false. In addition, the expression
   891  // needs to be easy to decorrelate, which means that the outer column reference
   892  // ("z" in the example) should not be part of a projection (since projections
   893  // are difficult to hoist above left joins). The following procedure solves the
   894  // mismatch between BOOL_OR and ANY, as well as avoids correlated projections:
   895  //
   896  //   1. Filter out false comparison rows with an initial filter. The result of
   897  //      ANY does not change, no matter how many false rows are added or removed.
   898  //      This step has the effect of mapping a set containing only false
   899  //      comparison rows to the empty set (which is desirable).
   900  //
   901  //   2. Step #1 leaves only true and null comparison rows. A null comparison row
   902  //      occurs when either the left or right comparison operand is null (Any
   903  //      only allows comparison operators that propagate nulls). Map each null
   904  //      row to a false row, but only in the case where the right operand is null
   905  //      (i.e. the operand that came from the subquery). The case where the left
   906  //      operand is null will be handled later.
   907  //
   908  //   3. Use the BOOL_OR aggregation function on the true/false values from step
   909  //      #2. If there is at least one true value, then BOOL_OR returns true. If
   910  //      there are no values (the empty set case), then BOOL_OR returns null.
   911  //      Because of the previous steps, this indicates that the original set
   912  //      contained only false values (or no values at all).
   913  //
   914  //   4. A True result from BOOL_OR is ambiguous. It could mean that the
   915  //      comparison returned true for one of the rows in the group. Or, it could
   916  //      mean that the left operand was null. The CASE statement ensures that
   917  //      True is only returned if the left operand was not null.
   918  //
   919  //   5. In addition, the CASE statement maps a null return value to false, and
   920  //      false to null. This matches ANY behavior.
   921  //
   922  // The following is a table showing the various interesting cases:
   923  //
   924  //         | subquery  | before        | after   | after
   925  //     z   | x values  | BOOL_OR       | BOOL_OR | CASE
   926  //   ------+-----------+---------------+---------+-------
   927  //     1   | (1)       | (true)        | true    | true
   928  //     1   | (1, null) | (true, false) | true    | true
   929  //     1   | (1, 2)    | (true)        | true    | true
   930  //     1   | (null)    | (false)       | false   | null
   931  //    null | (1)       | (true)        | true    | null
   932  //    null | (1, null) | (true, false) | true    | null
   933  //    null | (null)    | (false)       | false   | null
   934  //     2   | (1)       | (empty)       | null    | false
   935  //   *any* | (empty)   | (empty)       | null    | false
   936  //
   937  // It is important that the set given to BOOL_OR does not contain any null
   938  // values (the reason for step #2). Null is reserved for use by the
   939  // TryDecorrelateScalarGroupBy rule, which will push a left join into the
   940  // GroupBy. Null values produced by the left join will simply be ignored by
   941  // BOOL_OR, and so cannot be used for any other purpose.
   942  func (r *subqueryHoister) constructGroupByAny(
   943  	scalar opt.ScalarExpr, cmp opt.Operator, input memo.RelExpr,
   944  ) memo.RelExpr {
   945  	// When the scalar value is not a simple variable or constant expression,
   946  	// then cache its value using a projection, since it will be referenced
   947  	// multiple times.
   948  	if scalar.Op() != opt.VariableOp && !opt.IsConstValueOp(scalar) {
   949  		typ := scalar.DataType()
   950  		scalarColID := r.f.Metadata().AddColumn("scalar", typ)
   951  		r.hoisted = r.c.ProjectExtraCol(r.hoisted, scalar, scalarColID)
   952  		scalar = r.f.ConstructVariable(scalarColID)
   953  	}
   954  
   955  	inputVar := r.f.funcs.referenceSingleColumn(input)
   956  	notNullColID := r.f.Metadata().AddColumn("notnull", types.Bool)
   957  	aggColID := r.f.Metadata().AddColumn("bool_or", types.Bool)
   958  	aggVar := r.f.ConstructVariable(aggColID)
   959  	caseColID := r.f.Metadata().AddColumn("case", types.Bool)
   960  
   961  	return r.f.ConstructProject(
   962  		r.f.ConstructScalarGroupBy(
   963  			r.f.ConstructProject(
   964  				r.f.ConstructSelect(
   965  					input,
   966  					memo.FiltersExpr{r.f.ConstructFiltersItem(
   967  						r.f.ConstructIsNot(
   968  							r.f.funcs.ConstructBinary(cmp, scalar, inputVar),
   969  							memo.FalseSingleton,
   970  						),
   971  					)},
   972  				),
   973  				memo.ProjectionsExpr{r.f.ConstructProjectionsItem(
   974  					r.f.ConstructIsNot(inputVar, memo.NullSingleton),
   975  					notNullColID,
   976  				)},
   977  				opt.ColSet{},
   978  			),
   979  			memo.AggregationsExpr{r.f.ConstructAggregationsItem(
   980  				r.f.ConstructBoolOr(
   981  					r.f.ConstructVariable(notNullColID),
   982  				),
   983  				aggColID,
   984  			)},
   985  			memo.EmptyGroupingPrivate,
   986  		),
   987  		memo.ProjectionsExpr{r.f.ConstructProjectionsItem(
   988  			r.f.ConstructCase(
   989  				r.f.ConstructTrue(),
   990  				memo.ScalarListExpr{
   991  					r.f.ConstructWhen(
   992  						r.f.ConstructAnd(
   993  							aggVar,
   994  							r.f.ConstructIsNot(scalar, memo.NullSingleton),
   995  						),
   996  						r.f.ConstructTrue(),
   997  					),
   998  					r.f.ConstructWhen(
   999  						r.f.ConstructIs(aggVar, memo.NullSingleton),
  1000  						r.f.ConstructFalse(),
  1001  					),
  1002  				},
  1003  				memo.NullSingleton,
  1004  			),
  1005  			caseColID,
  1006  		)},
  1007  		opt.ColSet{},
  1008  	)
  1009  }