github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/memo/extract.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package memo
    12  
    13  import (
    14  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    15  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    16  	"github.com/cockroachdb/errors"
    17  )
    18  
    19  // This file contains various helper functions that extract useful information
    20  // from expressions.
    21  
    22  // CanExtractConstTuple returns true if the expression is a TupleOp with
    23  // constant values (a nested tuple of constant values is considered constant).
    24  func CanExtractConstTuple(e opt.Expr) bool {
    25  	return e.Op() == opt.TupleOp && CanExtractConstDatum(e)
    26  }
    27  
    28  // CanExtractConstDatum returns true if a constant datum can be created from the
    29  // given expression (tuples and arrays of constant values are considered
    30  // constant values). If CanExtractConstDatum returns true, then
    31  // ExtractConstDatum is guaranteed to work as well.
    32  func CanExtractConstDatum(e opt.Expr) bool {
    33  	if opt.IsConstValueOp(e) {
    34  		return true
    35  	}
    36  
    37  	if tup, ok := e.(*TupleExpr); ok {
    38  		for _, elem := range tup.Elems {
    39  			if !CanExtractConstDatum(elem) {
    40  				return false
    41  			}
    42  		}
    43  		return true
    44  	}
    45  
    46  	if arr, ok := e.(*ArrayExpr); ok {
    47  		for _, elem := range arr.Elems {
    48  			if !CanExtractConstDatum(elem) {
    49  				return false
    50  			}
    51  		}
    52  		return true
    53  	}
    54  
    55  	return false
    56  }
    57  
    58  // ExtractConstDatum returns the Datum that represents the value of an
    59  // expression with a constant value. An expression with a constant value is:
    60  //  - one that has a ConstValue tag, or
    61  //  - a tuple or array where all children are constant values.
    62  func ExtractConstDatum(e opt.Expr) tree.Datum {
    63  	switch t := e.(type) {
    64  	case *NullExpr:
    65  		return tree.DNull
    66  
    67  	case *TrueExpr:
    68  		return tree.DBoolTrue
    69  
    70  	case *FalseExpr:
    71  		return tree.DBoolFalse
    72  
    73  	case *ConstExpr:
    74  		return t.Value
    75  
    76  	case *TupleExpr:
    77  		datums := make(tree.Datums, len(t.Elems))
    78  		for i := range datums {
    79  			datums[i] = ExtractConstDatum(t.Elems[i])
    80  		}
    81  		return tree.NewDTuple(t.Typ, datums...)
    82  
    83  	case *ArrayExpr:
    84  		elementType := t.Typ.ArrayContents()
    85  		a := tree.NewDArray(elementType)
    86  		a.Array = make(tree.Datums, len(t.Elems))
    87  		for i := range a.Array {
    88  			a.Array[i] = ExtractConstDatum(t.Elems[i])
    89  			if a.Array[i] == tree.DNull {
    90  				a.HasNulls = true
    91  			} else {
    92  				a.HasNonNulls = true
    93  			}
    94  		}
    95  		return a
    96  	}
    97  	panic(errors.AssertionFailedf("non-const expression: %+v", e))
    98  }
    99  
   100  // ExtractAggFunc digs down into the given aggregate expression and returns the
   101  // aggregate function, skipping past any AggFilter or AggDistinct operators.
   102  func ExtractAggFunc(e opt.ScalarExpr) opt.ScalarExpr {
   103  	if filter, ok := e.(*AggFilterExpr); ok {
   104  		e = filter.Input
   105  	}
   106  
   107  	if distinct, ok := e.(*AggDistinctExpr); ok {
   108  		e = distinct.Input
   109  	}
   110  
   111  	if !opt.IsAggregateOp(e) {
   112  		panic(errors.AssertionFailedf("not an Aggregate"))
   113  	}
   114  
   115  	return e
   116  }
   117  
   118  // ExtractAggInputColumns returns the set of columns the aggregate depends on.
   119  func ExtractAggInputColumns(e opt.ScalarExpr) opt.ColSet {
   120  	var res opt.ColSet
   121  	if filter, ok := e.(*AggFilterExpr); ok {
   122  		res.Add(filter.Filter.(*VariableExpr).Col)
   123  		e = filter.Input
   124  	}
   125  
   126  	if distinct, ok := e.(*AggDistinctExpr); ok {
   127  		e = distinct.Input
   128  	}
   129  
   130  	if !opt.IsAggregateOp(e) {
   131  		panic(errors.AssertionFailedf("not an Aggregate"))
   132  	}
   133  
   134  	for i, n := 0, e.ChildCount(); i < n; i++ {
   135  		if variable, ok := e.Child(i).(*VariableExpr); ok {
   136  			res.Add(variable.Col)
   137  		}
   138  	}
   139  
   140  	return res
   141  }
   142  
   143  // ExtractAggFirstVar is given an aggregate expression and returns the Variable
   144  // expression for the first argument, skipping past modifiers like AggDistinct.
   145  func ExtractAggFirstVar(e opt.ScalarExpr) *VariableExpr {
   146  	e = ExtractAggFunc(e)
   147  	if e.ChildCount() == 0 {
   148  		panic(errors.AssertionFailedf("aggregate does not have any arguments"))
   149  	}
   150  
   151  	if variable, ok := e.Child(0).(*VariableExpr); ok {
   152  		return variable
   153  	}
   154  
   155  	panic(errors.AssertionFailedf("first aggregate input is not a Variable"))
   156  }
   157  
   158  // ExtractJoinEqualityColumns returns pairs of columns (one from the left side,
   159  // one from the right side) which are constrained to be equal in a join (and
   160  // have equivalent types).
   161  func ExtractJoinEqualityColumns(
   162  	leftCols, rightCols opt.ColSet, on FiltersExpr,
   163  ) (leftEq opt.ColList, rightEq opt.ColList) {
   164  	for i := range on {
   165  		condition := on[i].Condition
   166  		ok, left, right := ExtractJoinEquality(leftCols, rightCols, condition)
   167  		if !ok {
   168  			continue
   169  		}
   170  		// Don't allow any column to show up twice.
   171  		// TODO(radu): need to figure out the right thing to do in cases
   172  		// like: left.a = right.a AND left.a = right.b
   173  		duplicate := false
   174  		for i := range leftEq {
   175  			if leftEq[i] == left || rightEq[i] == right {
   176  				duplicate = true
   177  				break
   178  			}
   179  		}
   180  		if !duplicate {
   181  			leftEq = append(leftEq, left)
   182  			rightEq = append(rightEq, right)
   183  		}
   184  	}
   185  	return leftEq, rightEq
   186  }
   187  
   188  // ExtractJoinEqualityFilters returns the filters containing pairs of columns
   189  // (one from the left side, one from the right side) which are constrained to
   190  // be equal in a join (and have equivalent types).
   191  func ExtractJoinEqualityFilters(leftCols, rightCols opt.ColSet, on FiltersExpr) FiltersExpr {
   192  	// We want to avoid allocating a new slice unless strictly necessary.
   193  	var newFilters FiltersExpr
   194  	for i := range on {
   195  		condition := on[i].Condition
   196  		ok, _, _ := ExtractJoinEquality(leftCols, rightCols, condition)
   197  		if ok {
   198  			if newFilters != nil {
   199  				newFilters = append(newFilters, on[i])
   200  			}
   201  		} else {
   202  			if newFilters == nil {
   203  				newFilters = make(FiltersExpr, i, len(on)-1)
   204  				copy(newFilters, on[:i])
   205  			}
   206  		}
   207  	}
   208  	if newFilters != nil {
   209  		return newFilters
   210  	}
   211  	return on
   212  }
   213  
   214  func isVarEquality(condition opt.ScalarExpr) (leftVar, rightVar *VariableExpr, ok bool) {
   215  	if eq, ok := condition.(*EqExpr); ok {
   216  		if leftVar, ok := eq.Left.(*VariableExpr); ok {
   217  			if rightVar, ok := eq.Right.(*VariableExpr); ok {
   218  				return leftVar, rightVar, true
   219  			}
   220  		}
   221  	}
   222  	return nil, nil, false
   223  }
   224  
   225  // ExtractJoinEquality returns true if the given condition is a simple equality
   226  // condition with two variables (e.g. a=b), where one of the variables (returned
   227  // as "left") is in the set of leftCols and the other (returned as "right") is
   228  // in the set of rightCols.
   229  func ExtractJoinEquality(
   230  	leftCols, rightCols opt.ColSet, condition opt.ScalarExpr,
   231  ) (ok bool, left, right opt.ColumnID) {
   232  	lvar, rvar, ok := isVarEquality(condition)
   233  	if !ok {
   234  		return false, 0, 0
   235  	}
   236  
   237  	// Don't allow mixed types (see #22519).
   238  	if !lvar.DataType().Equivalent(rvar.DataType()) {
   239  		return false, 0, 0
   240  	}
   241  
   242  	if leftCols.Contains(lvar.Col) && rightCols.Contains(rvar.Col) {
   243  		return true, lvar.Col, rvar.Col
   244  	}
   245  	if leftCols.Contains(rvar.Col) && rightCols.Contains(lvar.Col) {
   246  		return true, rvar.Col, lvar.Col
   247  	}
   248  
   249  	return false, 0, 0
   250  }
   251  
   252  // ExtractRemainingJoinFilters calculates the remaining ON condition after
   253  // removing equalities that are handled separately. The given function
   254  // determines if an equality is redundant. The result is empty if there are no
   255  // remaining conditions.
   256  func ExtractRemainingJoinFilters(on FiltersExpr, leftEq, rightEq opt.ColList) FiltersExpr {
   257  	var newFilters FiltersExpr
   258  	for i := range on {
   259  		leftVar, rightVar, ok := isVarEquality(on[i].Condition)
   260  		if ok {
   261  			a, b := leftVar.Col, rightVar.Col
   262  			found := false
   263  			for j := range leftEq {
   264  				if (a == leftEq[j] && b == rightEq[j]) ||
   265  					(a == rightEq[j] && b == leftEq[j]) {
   266  					found = true
   267  					break
   268  				}
   269  			}
   270  			if found {
   271  				// Skip this condition.
   272  				continue
   273  			}
   274  		}
   275  		if newFilters == nil {
   276  			newFilters = make(FiltersExpr, 0, len(on)-i)
   277  		}
   278  		newFilters = append(newFilters, on[i])
   279  	}
   280  	return newFilters
   281  }
   282  
   283  // ExtractConstColumns returns columns in the filters expression that have been
   284  // constrained to fixed values.
   285  func ExtractConstColumns(
   286  	on FiltersExpr, mem *Memo, evalCtx *tree.EvalContext,
   287  ) (fixedCols opt.ColSet) {
   288  	for i := range on {
   289  		scalar := on[i]
   290  		scalarProps := scalar.ScalarProps()
   291  		if scalarProps.Constraints != nil && !scalarProps.Constraints.IsUnconstrained() {
   292  			fixedCols.UnionWith(scalarProps.Constraints.ExtractConstCols(evalCtx))
   293  		}
   294  	}
   295  	return fixedCols
   296  }
   297  
   298  // ExtractValueForConstColumn returns the constant value of a column returned by
   299  // ExtractConstColumns.
   300  func ExtractValueForConstColumn(
   301  	on FiltersExpr, mem *Memo, evalCtx *tree.EvalContext, col opt.ColumnID,
   302  ) tree.Datum {
   303  	for i := range on {
   304  		scalar := on[i]
   305  		scalarProps := scalar.ScalarProps()
   306  		if scalarProps.Constraints != nil {
   307  			if val := scalarProps.Constraints.ExtractValueForConstCol(evalCtx, col); val != nil {
   308  				return val
   309  			}
   310  		}
   311  	}
   312  	return nil
   313  }