github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optbuilder/union.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package optbuilder
    12  
    13  import (
    14  	"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
    15  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    16  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    19  )
    20  
    21  // buildUnionClause builds a set of memo groups that represent the given union
    22  // clause.
    23  //
    24  // See Builder.buildStmt for a description of the remaining input and
    25  // return values.
    26  func (b *Builder) buildUnionClause(
    27  	clause *tree.UnionClause, desiredTypes []*types.T, inScope *scope,
    28  ) (outScope *scope) {
    29  	leftScope := b.buildStmt(clause.Left, desiredTypes, inScope)
    30  	// Try to propagate types left-to-right, if we didn't already have desired
    31  	// types.
    32  	if len(desiredTypes) == 0 {
    33  		desiredTypes = leftScope.makeColumnTypes()
    34  	}
    35  	rightScope := b.buildStmt(clause.Right, desiredTypes, inScope)
    36  	return b.buildSetOp(clause.Type, clause.All, inScope, leftScope, rightScope)
    37  }
    38  
    39  func (b *Builder) buildSetOp(
    40  	typ tree.UnionType, all bool, inScope, leftScope, rightScope *scope,
    41  ) (outScope *scope) {
    42  	// Remove any hidden columns, as they are not included in the Union.
    43  	leftScope.removeHiddenCols()
    44  	rightScope.removeHiddenCols()
    45  
    46  	outScope = inScope.push()
    47  
    48  	// propagateTypesLeft/propagateTypesRight indicate whether we need to wrap
    49  	// the left/right side in a projection to cast some of the columns to the
    50  	// correct type.
    51  	// For example:
    52  	//   SELECT NULL UNION SELECT 1
    53  	// The type of NULL is unknown, and the type of 1 is int. We need to
    54  	// wrap the left side in a project operation with a Cast expression so the
    55  	// output column will have the correct type.
    56  	propagateTypesLeft, propagateTypesRight := b.checkTypesMatch(
    57  		leftScope, rightScope,
    58  		true, /* tolerateUnknownLeft */
    59  		true, /* tolerateUnknownRight */
    60  		typ.String(),
    61  	)
    62  
    63  	if propagateTypesLeft {
    64  		leftScope = b.propagateTypes(leftScope /* dst */, rightScope /* src */)
    65  	}
    66  	if propagateTypesRight {
    67  		rightScope = b.propagateTypes(rightScope /* dst */, leftScope /* src */)
    68  	}
    69  
    70  	// For UNION, we have to synthesize new output columns (because they contain
    71  	// values from both the left and right relations). This is not necessary for
    72  	// INTERSECT or EXCEPT, since these operations are basically filters on the
    73  	// left relation.
    74  	if typ == tree.UnionOp {
    75  		outScope.cols = make([]scopeColumn, 0, len(leftScope.cols))
    76  		for i := range leftScope.cols {
    77  			c := &leftScope.cols[i]
    78  			b.synthesizeColumn(outScope, string(c.name), c.typ, nil, nil /* scalar */)
    79  		}
    80  	} else {
    81  		outScope.appendColumnsFromScope(leftScope)
    82  	}
    83  
    84  	// Create the mapping between the left-side columns, right-side columns and
    85  	// new columns (if needed).
    86  	leftCols := colsToColList(leftScope.cols)
    87  	rightCols := colsToColList(rightScope.cols)
    88  	newCols := colsToColList(outScope.cols)
    89  
    90  	left := leftScope.expr.(memo.RelExpr)
    91  	right := rightScope.expr.(memo.RelExpr)
    92  	private := memo.SetPrivate{LeftCols: leftCols, RightCols: rightCols, OutCols: newCols}
    93  
    94  	if all {
    95  		switch typ {
    96  		case tree.UnionOp:
    97  			outScope.expr = b.factory.ConstructUnionAll(left, right, &private)
    98  		case tree.IntersectOp:
    99  			outScope.expr = b.factory.ConstructIntersectAll(left, right, &private)
   100  		case tree.ExceptOp:
   101  			outScope.expr = b.factory.ConstructExceptAll(left, right, &private)
   102  		}
   103  	} else {
   104  		switch typ {
   105  		case tree.UnionOp:
   106  			outScope.expr = b.factory.ConstructUnion(left, right, &private)
   107  		case tree.IntersectOp:
   108  			outScope.expr = b.factory.ConstructIntersect(left, right, &private)
   109  		case tree.ExceptOp:
   110  			outScope.expr = b.factory.ConstructExcept(left, right, &private)
   111  		}
   112  	}
   113  
   114  	return outScope
   115  }
   116  
   117  // checkTypesMatch is used when the columns must match between two scopes (e.g.
   118  // for a UNION). Throws an error if the scopes don't have the same number of
   119  // columns, or when column types don't match 1-1, except:
   120  //  - if tolerateUnknownLeft is set and the left column has Unknown type while
   121  //    the right has a known type (in this case it returns propagateToLeft=true).
   122  //  - if tolerateUnknownRight is set and the right column has Unknown type while
   123  //    the right has a known type (in this case it returns propagateToRight=true).
   124  //
   125  // clauseTag is used only in error messages.
   126  //
   127  // TODO(dan): This currently checks whether the types are exactly the same,
   128  // but Postgres is more lenient:
   129  // http://www.postgresql.org/docs/9.5/static/typeconv-union-case.html.
   130  func (b *Builder) checkTypesMatch(
   131  	leftScope, rightScope *scope,
   132  	tolerateUnknownLeft bool,
   133  	tolerateUnknownRight bool,
   134  	clauseTag string,
   135  ) (propagateToLeft, propagateToRight bool) {
   136  	if len(leftScope.cols) != len(rightScope.cols) {
   137  		panic(pgerror.Newf(
   138  			pgcode.Syntax,
   139  			"each %s query must have the same number of columns: %d vs %d",
   140  			clauseTag, len(leftScope.cols), len(rightScope.cols),
   141  		))
   142  	}
   143  
   144  	for i := range leftScope.cols {
   145  		l := &leftScope.cols[i]
   146  		r := &rightScope.cols[i]
   147  
   148  		if l.typ.Equivalent(r.typ) {
   149  			continue
   150  		}
   151  
   152  		// Note that Unknown types are equivalent so at this point at most one of
   153  		// the types can be Unknown.
   154  		if l.typ.Family() == types.UnknownFamily && tolerateUnknownLeft {
   155  			propagateToLeft = true
   156  			continue
   157  		}
   158  		if r.typ.Family() == types.UnknownFamily && tolerateUnknownRight {
   159  			propagateToRight = true
   160  			continue
   161  		}
   162  
   163  		panic(pgerror.Newf(
   164  			pgcode.DatatypeMismatch,
   165  			"%v types %s and %s cannot be matched", clauseTag, l.typ, r.typ,
   166  		))
   167  	}
   168  	return propagateToLeft, propagateToRight
   169  }
   170  
   171  // propagateTypes propagates the types of the source columns to the destination
   172  // columns by wrapping the destination in a Project operation. The Project
   173  // operation passes through columns that already have the correct type, and
   174  // creates cast expressions for those that don't.
   175  func (b *Builder) propagateTypes(dst, src *scope) *scope {
   176  	expr := dst.expr.(memo.RelExpr)
   177  	dstCols := dst.cols
   178  
   179  	dst = dst.push()
   180  	dst.cols = make([]scopeColumn, 0, len(dstCols))
   181  
   182  	for i := 0; i < len(dstCols); i++ {
   183  		dstType := dstCols[i].typ
   184  		srcType := src.cols[i].typ
   185  		if dstType.Family() == types.UnknownFamily && srcType.Family() != types.UnknownFamily {
   186  			// Create a new column which casts the old column to the correct type.
   187  			castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), srcType)
   188  			b.synthesizeColumn(dst, string(dstCols[i].name), srcType, nil /* expr */, castExpr)
   189  		} else {
   190  			// The column is already the correct type, so add it as a passthrough
   191  			// column.
   192  			dst.appendColumn(&dstCols[i])
   193  		}
   194  	}
   195  	dst.expr = b.constructProject(expr, dst.cols)
   196  	return dst
   197  }